From aa8062e2ec9e04a3e9eacad7e385c10e834e7c69 Mon Sep 17 00:00:00 2001
From: Chion Tang
Date: Wed, 14 Mar 2018 02:08:40 +0000
Subject: feature: hashtable for original src & original tuple

---
 xt_FULLCONENAT.c | 156 +++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 101 insertions(+), 55 deletions(-)

diff --git a/xt_FULLCONENAT.c b/xt_FULLCONENAT.c
index 54e8c97..8e56e21 100644
--- a/xt_FULLCONENAT.c
+++ b/xt_FULLCONENAT.c
@@ -3,6 +3,8 @@
 #include <linux/types.h>
 #include <linux/hashtable.h>
 #include <linux/slab.h>
+#include <linux/random.h>
+#include <linux/jhash.h>
 #include <linux/netdevice.h>
 #include <linux/inetdevice.h>
 #include <linux/netfilter.h>
@@ -10,12 +12,17 @@
 #include <linux/netfilter/x_tables.h>
 #include <net/netfilter/nf_nat.h>
 #include <net/netfilter/nf_conntrack.h>
+#include <net/netns/hash.h>
 #include <net/netfilter/nf_conntrack_core.h>
 #include <net/netfilter/nf_conntrack_zones.h>
 #include <net/netfilter/nf_conntrack_ecache.h>
 #include <net/netfilter/nf_conntrack_tuple.h>
 #include <net/netfilter/nf_nat_l3proto.h>
 
+#define HASH_2(x, y) ((x + y) / 2 * (x + y + 1) + y)
+
+#define HASHTABLE_BUCKET_BITS 10
+
 #if LINUX_VERSION_CODE < KERNEL_VERSION(4, 10, 0)
 static inline int nf_ct_netns_get(struct net *net, u8 nfproto) { return 0; }
@@ -43,7 +50,10 @@ struct nat_mapping {
   int ifindex; /* external interface index*/
 
   struct nf_conntrack_tuple original_tuple;
-  struct hlist_node node;
+  struct hlist_node node_by_ext_port;
+  struct hlist_node node_by_original_src;
+  struct hlist_node node_by_original_tuple;
+
 };
 
 struct nf_ct_net_event {
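
Two details above are easy to miss. Each nat_mapping now embeds one hlist_node per table, so a single allocation can sit in all three indexes at once, and HASH_2 is a Cantor-pairing-style combiner that folds an (address, port) pair into one u32 key; its integer division truncates when x + y is odd, which is harmless since the result only picks a bucket. Below is a minimal userspace sketch of the key/bucket arithmetic (not part of the patch; bucket_of() is a plain mask standing in for the kernel's hash_min(), which applies a multiplicative hash):

/* hash2_demo.c: userspace sketch, not part of the patch. */
#include <stdio.h>
#include <stdint.h>

#define HASH_2(x, y) ((x + y) / 2 * (x + y + 1) + y)
#define HASHTABLE_BUCKET_BITS 10

static unsigned int bucket_of(uint32_t key) {
  /* 2^10 = 1024 buckets, as in DEFINE_HASHTABLE(..., 10) */
  return key & ((1u << HASHTABLE_BUCKET_BITS) - 1);
}

int main(void) {
  uint32_t addr = 0xc0a80102;               /* 192.168.1.2 */
  uint16_t ports[] = { 1024, 1025, 40000 };
  unsigned int i;

  for (i = 0; i < 3; i++) {
    /* unsigned arithmetic wraps on overflow, which is fine for a key */
    uint32_t key = HASH_2(addr, (uint32_t)ports[i]);
    printf("src %08x:%-5u -> key %10u -> bucket %4u\n",
           (unsigned)addr, (unsigned)ports[i], (unsigned)key, bucket_of(key));
  }
  return 0;
}

What matters is only that equal (address, port) pairs always yield the same key, so a later lookup lands in the same bucket as the insertion.
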
@@ -59,22 +69,34 @@
 
 static LIST_HEAD(nf_ct_net_event_list);
 static DEFINE_MUTEX(nf_ct_net_event_lock);
 
-static DEFINE_HASHTABLE(mapping_table, 10);
+static DEFINE_HASHTABLE(mapping_table_by_ext_port, HASHTABLE_BUCKET_BITS);
+static DEFINE_HASHTABLE(mapping_table_by_original_src, HASHTABLE_BUCKET_BITS);
+static DEFINE_HASHTABLE(mapping_table_by_original_tuple, HASHTABLE_BUCKET_BITS);
 
 static DEFINE_SPINLOCK(fullconenat_lock);
 
-static struct nat_mapping* get_mapping(const uint16_t port, const int ifindex, const int create_new) {
-  struct nat_mapping *p_current, *p_new;
-
-  hash_for_each_possible(mapping_table, p_current, node, port) {
-    if (p_current->port == port && p_current->ifindex == ifindex) {
-      return p_current;
-    }
-  }
+static unsigned int nf_conntrack_hash_rnd __read_mostly;
+static u32 hash_conntrack_raw(const struct nf_conntrack_tuple *tuple,
+                              const struct net *net) {
+  unsigned int n;
+  u32 seed;
+
+  get_random_once(&nf_conntrack_hash_rnd, sizeof(nf_conntrack_hash_rnd));
+
+  /* The direction must be ignored, so we hash everything up to the
+   * destination ports (which is a multiple of 4) and treat the last
+   * three bytes manually.
+   */
+  seed = nf_conntrack_hash_rnd ^ net_hash_mix(net);
+  n = (sizeof(tuple->src) + sizeof(tuple->dst.u3)) / sizeof(u32);
+  return jhash2((u32 *)tuple, n, seed ^
+                (((__force __u16)tuple->dst.u.all << 16) |
+                 tuple->dst.protonum));
+}
 
-  if (!create_new) {
-    return NULL;
-  }
+static struct nat_mapping* allocate_mapping(const struct net *net, const uint16_t port, const __be32 int_addr, const uint16_t int_port, const int ifindex, const struct nf_conntrack_tuple* original_tuple) {
+  struct nat_mapping *p_new;
+  u32 hash_tuple, hash_src;
 
   p_new = kmalloc(sizeof(struct nat_mapping), GFP_ATOMIC);
   if (p_new == NULL) {
@@ -82,34 +104,54 @@ static struct nat_mapping* get_mapping(const uint16_t port, const int ifindex, c
     return NULL;
   }
   p_new->port = port;
-  p_new->int_addr = 0;
-  p_new->int_port = 0;
+  p_new->int_addr = int_addr;
+  p_new->int_port = int_port;
   p_new->ifindex = ifindex;
-  memset(&p_new->original_tuple, 0, sizeof(struct nf_conntrack_tuple));
+  memcpy(&p_new->original_tuple, original_tuple, sizeof(struct nf_conntrack_tuple));
 
-  hash_add(mapping_table, &p_new->node, port);
+  hash_tuple = hash_conntrack_raw(original_tuple, net);
+  hash_src = HASH_2(int_addr, (u32)int_port);
+
+  hash_add(mapping_table_by_ext_port, &p_new->node_by_ext_port, port);
+  hash_add(mapping_table_by_original_tuple, &p_new->node_by_original_tuple, hash_tuple);
+  hash_add(mapping_table_by_original_src, &p_new->node_by_original_src, hash_src);
 
   return p_new;
 }
 
+static struct nat_mapping* get_mapping_by_ext_port(const uint16_t port, const int ifindex) {
+  struct nat_mapping *p_current;
+
+  hash_for_each_possible(mapping_table_by_ext_port, p_current, node_by_ext_port, port) {
+    if (p_current->port == port && p_current->ifindex == ifindex) {
+      return p_current;
+    }
+  }
+
+  return NULL;
+}
+
 static struct nat_mapping* get_mapping_by_original_src(const __be32 src_ip, const uint16_t src_port, const int ifindex) {
   struct nat_mapping *p_current;
-  int i;
 
-  hash_for_each(mapping_table, i, p_current, node) {
+  u32 hash_src = HASH_2(src_ip, (u32)src_port);
+
+  hash_for_each_possible(mapping_table_by_original_src, p_current, node_by_original_src, hash_src) {
     if (p_current->int_addr == src_ip && p_current->int_port == src_port
       && p_current->ifindex == ifindex) {
       return p_current;
     }
  }
+
   return NULL;
 }
 
-static struct nat_mapping* get_mapping_by_original_tuple(const struct nf_conntrack_tuple* tuple) {
+static struct nat_mapping* get_mapping_by_original_tuple(const struct net *net, const struct nf_conntrack_tuple* tuple) {
   struct nat_mapping *p_current;
-  int i;
 
-  if (tuple == NULL) {
+  u32 hash_tuple = hash_conntrack_raw(tuple, net);
+
+  if (net == NULL || tuple == NULL) {
     return NULL;
   }
-  hash_for_each(mapping_table, i, p_current, node) {
+  hash_for_each_possible(mapping_table_by_original_tuple, p_current, node_by_original_tuple, hash_tuple) {
     if (nf_ct_tuple_equal(&p_current->original_tuple, tuple)) {
       return p_current;
     }
@@ -117,6 +159,16 @@ static struct nat_mapping* get_mapping_by_original_tuple(const struct nf_conntra
   return NULL;
 }
 
+static void kill_mapping(struct nat_mapping *mapping) {
+  if (mapping == NULL) {
+    return;
+  }
+  hash_del(&mapping->node_by_ext_port);
+  hash_del(&mapping->node_by_original_src);
+  hash_del(&mapping->node_by_original_tuple);
+  kfree(mapping);
+}
+
 static void destroy_mappings(void) {
   struct nat_mapping *p_current;
   struct hlist_node *tmp;
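
hash_conntrack_raw() borrows its scheme from the kernel's conntrack core: jhash2() covers only the direction-independent prefix of the tuple (source address and port plus destination address), the destination port and protocol number are folded into the seed, and the tuple's direction flag is excluded entirely, so a tuple hashes the same whichever conntrack direction it was recorded from. Seeding with a boot-time random value XORed with net_hash_mix(net) keeps bucket placement unpredictable to remote senders and distinct per network namespace. A self-contained userspace sketch of that shape (illustrative layout; fnv1a32() is a deliberately simple stand-in for jhash2()):

/* tuple_hash_demo.c: userspace sketch, not part of the patch. */
#include <stdio.h>
#include <stdint.h>
#include <stddef.h>
#include <string.h>

struct demo_tuple {
  uint32_t src_ip;
  uint16_t src_port;
  uint32_t dst_ip;    /* end of the hashed prefix */
  uint16_t dst_port;  /* folded into the seed instead */
  uint8_t  protonum;  /* folded into the seed instead */
  uint8_t  dir;       /* direction flag: must NOT influence the hash */
};

/* toy stand-in for jhash2(): FNV-1a over 32-bit words */
static uint32_t fnv1a32(const uint32_t *w, unsigned int n, uint32_t seed) {
  uint32_t h = 2166136261u ^ seed;
  unsigned int i;
  for (i = 0; i < n; i++) {
    h ^= w[i];
    h *= 16777619u;
  }
  return h;
}

static uint32_t hash_tuple(const struct demo_tuple *t, uint32_t rnd) {
  /* hash everything up to dst_port; dir is left out, as in the patch */
  unsigned int n = offsetof(struct demo_tuple, dst_port) / sizeof(uint32_t);
  uint32_t seed = rnd ^ (((uint32_t)t->dst_port << 16) | t->protonum);
  uint32_t words[3];

  memcpy(words, t, sizeof(words));  /* prefix only: src ip/port + dst ip */
  return fnv1a32(words, n, seed);
}

int main(void) {
  static struct demo_tuple fwd = { 0x0a000001, 5000, 0x08080808, 53, 17, 0 };
  struct demo_tuple same_flow;

  memcpy(&same_flow, &fwd, sizeof(same_flow));
  same_flow.dir = 1;  /* same tuple content, different direction flag */
  printf("dir=0: %08x\ndir=1: %08x\n",
         hash_tuple(&fwd, 0xdeadbeef), hash_tuple(&same_flow, 0xdeadbeef));
  return 0;
}

Both printed hashes are identical, which is exactly the property get_mapping_by_original_tuple() relies on when it matches stored mappings against tuples taken from live conntrack entries.
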
@@ -124,9 +176,8 @@ static void destroy_mappings(void) {
 
   spin_lock(&fullconenat_lock);
 
-  hash_for_each_safe(mapping_table, i, tmp, p_current, node) {
-    hash_del(&p_current->node);
-    kfree(p_current);
+  hash_for_each_safe(mapping_table_by_ext_port, i, tmp, p_current, node_by_ext_port) {
+    kill_mapping(p_current);
   }
 
   spin_unlock(&fullconenat_lock);
@@ -135,7 +186,7 @@ static void destroy_mappings(void) {
 /* check if a mapping is valid.
  * possibly delete and free an invalid mapping.
  * the mapping should not be used anymore after check_mapping() returns 0. */
-static int check_mapping(struct nat_mapping* mapping, struct net *net, struct nf_conntrack_zone *zone)
+static int check_mapping(struct net *net, struct nf_conntrack_zone *zone, struct nat_mapping* mapping)
 {
   struct nf_conntrack_tuple_hash *original_tuple_hash;
 
@@ -161,14 +212,14 @@ del_mapping:
   /* for dying/unconfirmed conntracks, an IPCT_DESTROY event may NOT be fired.
    * so we manually kill one of those conntracks once we acquire one. */
   pr_debug("xt_FULLCONENAT: check_mapping(): kill dying/unconfirmed mapping at ext port %d\n", mapping->port);
-  hash_del(&mapping->node);
-  kfree(mapping);
+  kill_mapping(mapping);
   return 0;
 }
 
 /* conntrack destroy event callback function */
 static int ct_event_cb(unsigned int events, struct nf_ct_event *item) {
   struct nf_conn *ct;
+  struct net *net;
   struct nf_conntrack_tuple *ct_tuple_origin;
   struct nat_mapping *mapping;
   uint8_t protonum;
@@ -179,6 +230,8 @@ static int ct_event_cb(unsigned int events, struct nf_ct_event *item) {
     return 0;
   }
 
+  net = nf_ct_net(ct);
+
   /* take the original tuple and find the corresponding mapping */
   ct_tuple_origin = &(ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple);
 
@@ -189,7 +242,7 @@ static int ct_event_cb(unsigned int events, struct nf_ct_event *item) {
 
   spin_lock(&fullconenat_lock);
 
-  mapping = get_mapping_by_original_tuple(ct_tuple_origin);
+  mapping = get_mapping_by_original_tuple(net, ct_tuple_origin);
   if (mapping == NULL) {
     spin_unlock(&fullconenat_lock);
     return 0;
@@ -197,8 +250,7 @@ static int ct_event_cb(unsigned int events, struct nf_ct_event *item) {
 
   /* then kill it */
   pr_debug("xt_FULLCONENAT: ct_event_cb(): kill expired mapping at ext port %d\n", mapping->port);
-  hash_del(&mapping->node);
-  kfree(mapping);
+  kill_mapping(mapping);
 
   spin_unlock(&fullconenat_lock);
 
@@ -221,7 +273,7 @@ static __be32 get_device_ip(const struct net_device* dev) {
   }
 }
 
-static uint16_t find_appropriate_port(const uint16_t original_port, const struct nf_nat_ipv4_range *range, const int ifindex, struct net *net, struct nf_conntrack_zone *zone) {
+static uint16_t find_appropriate_port(struct net *net, struct nf_conntrack_zone *zone, const uint16_t original_port, const int ifindex, const struct nf_nat_ipv4_range *range) {
   uint16_t min, start, selected, range_size, i;
   struct nat_mapping* mapping = NULL;
 
@@ -245,8 +297,8 @@ static uint16_t find_appropriate_port(const uint16_t original_port, const struct
   if ((original_port >= min && original_port <= min + range_size - 1)
     || !(range->flags & NF_NAT_RANGE_PROTO_SPECIFIED)) {
     /* 1. try to preserve the port if it's available */
-    mapping = get_mapping(original_port, ifindex, 0);
-    if (mapping == NULL || !(check_mapping(mapping, net, zone))) {
+    mapping = get_mapping_by_ext_port(original_port, ifindex);
+    if (mapping == NULL || !(check_mapping(net, zone, mapping))) {
       return original_port;
     }
   }
@@ -258,14 +310,18 @@ static uint16_t find_appropriate_port(const uint16_t original_port, const struct
   for (i = 0; i < range_size; i++) {
     /* 2. try to find an available port */
     selected = min + ((start + i) % range_size);
-    mapping = get_mapping(selected, ifindex, 0);
-    if (mapping == NULL || !(check_mapping(mapping, net, zone))) {
+    mapping = get_mapping_by_ext_port(selected, ifindex);
+    if (mapping == NULL || !(check_mapping(net, zone, mapping))) {
       return selected;
     }
   }
 
-  /* 3. at least we tried. rewrite a previous mapping. */
-  return min + start;
+  /* 3. at least we tried. override a previous mapping. */
+  selected = min + start;
+  mapping = get_mapping_by_ext_port(selected, ifindex);
+  kill_mapping(mapping);
+
+  return selected;
 }
 
 static unsigned int fullconenat_tg(struct sk_buff *skb, const struct xt_action_param *par)
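
find_appropriate_port() now runs its three steps entirely through the by-ext-port index: keep the caller's source port when it is free, otherwise probe the whole range from a random starting offset, and as a last resort evict whichever mapping occupies the slot at that offset; kill_mapping() replaces the old silent return of min + start, so the overridden entry really leaves all three tables. A runnable userspace sketch of the same policy against a toy occupancy array (constants and names are illustrative, not part of the patch):

/* port_select_demo.c: userspace sketch, not part of the patch. */
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <time.h>

#define MIN_PORT 1024
#define RANGE_SIZE 4096

/* stand-in for get_mapping_by_ext_port() + check_mapping() */
static int port_taken[65536];

static uint16_t find_port(uint16_t original_port) {
  /* 1. preserve the source port if it is free */
  if (original_port >= MIN_PORT &&
      original_port < MIN_PORT + RANGE_SIZE &&
      !port_taken[original_port])
    return original_port;

  /* 2. otherwise probe the whole range from a random start */
  uint16_t start = (uint16_t)(rand() % RANGE_SIZE);
  uint32_t i;
  for (i = 0; i < RANGE_SIZE; i++) {
    uint16_t selected = MIN_PORT + (uint16_t)((start + i) % RANGE_SIZE);
    if (!port_taken[selected])
      return selected;
  }

  /* 3. range exhausted: evict whoever holds the slot at the random
   * start (the patch does this with kill_mapping()) and take it over */
  uint16_t selected = MIN_PORT + start;
  port_taken[selected] = 0;
  return selected;
}

int main(void) {
  uint16_t a, b;

  srand((unsigned)time(NULL));

  a = find_port(5000);   /* free: the source port is preserved */
  port_taken[a] = 1;
  b = find_port(5000);   /* taken now: falls back to probing */
  port_taken[b] = 1;

  printf("first flow got %u, second flow got %u\n", a, b);
  return 0;
}
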
@@ -322,12 +378,12 @@
     spin_lock(&fullconenat_lock);
 
     /* find an active mapping based on the inbound port */
-    mapping = get_mapping(port, ifindex, 0);
+    mapping = get_mapping_by_ext_port(port, ifindex);
     if (mapping == NULL) {
       spin_unlock(&fullconenat_lock);
       return ret;
     }
-    if (check_mapping(mapping, net, zone)) {
+    if (check_mapping(net, zone, mapping)) {
       newrange.flags = NF_NAT_RANGE_MAP_IPS | NF_NAT_RANGE_PROTO_SPECIFIED;
       newrange.min_addr.ip = mapping->int_addr;
       newrange.max_addr.ip = mapping->int_addr;
@@ -354,7 +410,7 @@ static unsigned int fullconenat_tg(struct sk_buff *skb, const struct xt_action_p
     original_port = be16_to_cpu((ct_tuple_origin->src).u.udp.port);
 
     src_mapping = get_mapping_by_original_src(ip, original_port, ifindex);
 
-    if (src_mapping != NULL && check_mapping(src_mapping, net, zone)) {
+    if (src_mapping != NULL && check_mapping(net, zone, src_mapping)) {
 
       /* outbound nat: if a previously established mapping is active,
        * we will reuse that mapping. */
@@ -364,7 +420,7 @@ static unsigned int fullconenat_tg(struct sk_buff *skb, const struct xt_action_p
       newrange.max_proto = newrange.min_proto;
 
     } else {
-      want_port = find_appropriate_port(original_port, range, ifindex, net, zone);
+      want_port = find_appropriate_port(net, zone, original_port, ifindex, range);
 
       newrange.flags = NF_NAT_RANGE_MAP_IPS | NF_NAT_RANGE_PROTO_SPECIFIED;
       newrange.min_proto.udp.port = cpu_to_be16(want_port);
@@ -390,19 +446,9 @@ static unsigned int fullconenat_tg(struct sk_buff *skb, const struct xt_action_p
     port = be16_to_cpu((ct_tuple->dst).u.udp.port);
 
     /* save the mapping information into our mapping table */
-    mapping = get_mapping(port, ifindex, 1);
-    if (mapping == NULL) {
-      spin_unlock(&fullconenat_lock);
-      return ret;
-    }
-    mapping->int_addr = ip;
-    mapping->int_port = original_port;
-    mapping->ifindex = ifindex;
-    /* save the original source tuple */
-    memcpy(&mapping->original_tuple, ct_tuple_origin, sizeof(struct nf_conntrack_tuple));
-
-    spin_unlock(&fullconenat_lock);
+    mapping = allocate_mapping(net, port, ip, original_port, ifindex, ct_tuple_origin);
 
+    spin_unlock(&fullconenat_lock);
   }
 
   return ret;
 }
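
Taken together, the three indexes are what make the full-cone behaviour cheap: outbound packets reuse an existing mapping through the (internal address, internal port) key, while an unsolicited inbound packet from any remote host resolves through the external-port key alone. A toy userspace model of that lifecycle (illustrative names, not part of the patch; the real module keys its tables with the hashes shown earlier and checks liveness via conntrack):

/* fullcone_demo.c: userspace sketch, not part of the patch. */
#include <stdio.h>
#include <stdint.h>

struct mapping {
  uint32_t int_addr;
  uint16_t int_port;
  uint16_t ext_port;
  int used;
};

static struct mapping table[64];  /* toy stand-in for the three hashtables */

static struct mapping *by_original_src(uint32_t addr, uint16_t port) {
  int i;
  for (i = 0; i < 64; i++)
    if (table[i].used && table[i].int_addr == addr && table[i].int_port == port)
      return &table[i];
  return NULL;
}

static struct mapping *by_ext_port(uint16_t ext_port) {
  int i;
  for (i = 0; i < 64; i++)
    if (table[i].used && table[i].ext_port == ext_port)
      return &table[i];
  return NULL;
}

static uint16_t outbound(uint32_t addr, uint16_t port, uint16_t want) {
  struct mapping *m = by_original_src(addr, port);
  int i;
  if (m)          /* reuse: every flow from addr:port shares one ext port */
    return m->ext_port;
  for (i = 0; i < 64; i++)
    if (!table[i].used) {
      table[i] = (struct mapping){ addr, port, want, 1 };
      return want;
    }
  return 0;
}

int main(void) {
  uint16_t p1 = outbound(0x0a000002, 5000, 40001);  /* flow to peer A */
  uint16_t p2 = outbound(0x0a000002, 5000, 40002);  /* flow to peer B: reused */
  struct mapping *m = by_ext_port(p1);              /* unsolicited inbound */

  printf("ext port %u == %u, inbound -> %08x:%u\n",
         p1, p2, m ? (unsigned)m->int_addr : 0, (unsigned)(m ? m->int_port : 0));
  return 0;
}

In a deployment the target is typically installed on both NAT hooks, e.g. iptables -t nat -A POSTROUTING -o eth0 -j FULLCONENAT paired with a matching PREROUTING rule on the same interface (interface name illustrative).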