#include <endian.h>               
#include <stdbool.h>              
#include <stddef.h>               
#include <stdint.h>               
#include <stdio.h>                
#include <stdlib.h>               
#include <string.h>               
 
#include "main.h"        
 
/* Number of ACL result categories used by this example (single category). */
#define MAX_CATEGORIES     1
/* High bits stamped into a rule's userdata to mark it as a "deny" rule;
 * the signature is also visible in classification results (see the deny
 * test on the result value in the forwarding loop). */
#define ACL_DENY_SIGNATURE 0xf0000000
/* Upper bound on the number of ACL rules. */
#define MAX_ACL_RULE_NUM   100000
/* Page size used when dumping rules over the UDS control interface. */
#define ACL_RULES_PER_PAGE 32

/* Per-rule field indices inside each cne_acl_rule. */
enum { PROTO_FIELD_IPV4, SRC_FIELD_IPV4, DST_FIELD_IPV4, NUM_FIELDS_IPV4 };

/* Input-field indices used by the ACL field definitions. */
enum { ACL_IPV4VLAN_PROTO, ACL_IPV4VLAN_SRC, ACL_IPV4VLAN_DST, ACL_IPV4VLAN_NUM };
 
struct acl_classify_t {
    const uint8_t *data_ptrs[BURST_SIZE];
    uint32_t acl_results[BURST_SIZE];
};
 
                                         .rule_size    = CNE_ACL_RULE_SZ(NUM_FIELDS_IPV4),
                                         .max_rule_num = MAX_ACL_RULE_NUM};
 
/* Byte geometry used to point the classifier into the IPv4 header. */
#define ETH_HDR_LEN         (sizeof(struct cne_ether_hdr))
#define IPV4_PROTO_OFFSET   (offsetof(struct cne_ipv4_hdr, next_proto_id))
/* Offset of the ACL match data (next_proto_id) from the start of the frame. */
#define ACL_DATA_OFFSET     (ETH_HDR_LEN + IPV4_PROTO_OFFSET)
#define SRC_ADDR_OFFSET     (offsetof(struct cne_ipv4_hdr, src_addr))
#define DST_ADDR_OFFSET     (offsetof(struct cne_ipv4_hdr, dst_addr))
/* Field offsets relative to ACL_DATA_OFFSET (the ACL input base). */
#define PROTO_ACL_OFFSET    0
#define SRC_ADDR_ACL_OFFSET (SRC_ADDR_OFFSET - IPV4_PROTO_OFFSET)
#define DST_ADDR_ACL_OFFSET (DST_ADDR_OFFSET - IPV4_PROTO_OFFSET)
 
    .num_fields     = 3,
    .defs           = {
        {
            .type        = CNE_ACL_FIELD_TYPE_BITMASK,
            .size        = sizeof(uint8_t),
            .field_index = PROTO_FIELD_IPV4,
            .input_index = ACL_IPV4VLAN_PROTO,
            .offset      = PROTO_ACL_OFFSET,
        },
        {
            .type        = CNE_ACL_FIELD_TYPE_MASK,
            .size        = sizeof(uint32_t),
            .field_index = SRC_FIELD_IPV4,
            .input_index = ACL_IPV4VLAN_SRC,
            .offset      = SRC_ADDR_ACL_OFFSET,
        },
        {
            .type        = CNE_ACL_FIELD_TYPE_MASK,
            .size        = sizeof(uint32_t),
            .field_index = DST_FIELD_IPV4,
            .input_index = ACL_IPV4VLAN_DST,
            .offset      = DST_ADDR_ACL_OFFSET,
        },
    }
};
 
/* Compiled ACL context used by the fast path; rebuilt via the UDS commands. */
static struct cne_acl_ctx *ctx;
/* Serializes control-path access to `ctx` and `acl_rules`. */
static pthread_mutex_t ctx_mutex = PTHREAD_MUTEX_INITIALIZER;

/* Growable, contiguous staging table of cne_acl_rule entries. */
static struct acl_rule_table {
    uint8_t *rules; /* rule storage, len * rule_sz bytes in use */
    size_t len;     /* number of rules currently stored */
    size_t sz;      /* capacity as a rule count (see acl_tbl_ensure_capacity) */
    size_t rule_sz; /* size in bytes of one rule entry */
} acl_rules = {.rule_sz = CNE_ACL_RULE_SZ(NUM_FIELDS_IPV4)};

/* Parsed form of one user-supplied ACL rule. */
struct acl_rule_desc {
    uint32_t src_addr; /* source IPv4 address, host byte order */
    uint8_t src_msk;   /* source prefix length, 0..32 */
    uint32_t dst_addr; /* destination IPv4 address, host byte order */
    uint8_t dst_msk;   /* destination prefix length, 0..32 */
    bool deny;         /* true = deny rule, false = allow rule */
};

/* Number of UDS display pages needed to show x rules (rounds up). */
#define ACL_NUM_PAGES(x) (((x) / ACL_RULES_PER_PAGE) + !!((x) % ACL_RULES_PER_PAGE))
 
/* Return a pointer to rule number `idx` inside the contiguous rule table. */
static struct cne_acl_rule *
acl_tbl_get_rule(struct acl_rule_table *tbl, const size_t idx)
{
    /* Rules are packed back to back; step idx * rule_sz bytes in. */
    const size_t byte_off = idx * tbl->rule_sz;

    return (struct cne_acl_rule *)CNE_PTR_ADD(tbl->rules, byte_off);
}
 
/*
 * Grow the rule table so it can hold at least new_sz rules.
 *
 * Returns 0 on success (including when no growth is needed) or -ENOMEM on
 * allocation failure / size overflow.  On failure the existing buffer is
 * left untouched.
 */
static int
acl_tbl_resize(struct acl_rule_table *tbl, const size_t new_sz)
{
    size_t newlen;
    uint8_t *newmem;

    /* Nothing to do if the table already holds new_sz rules. */
    if (new_sz <= tbl->sz)
        return 0;

    /* Guard the byte-length multiplication against size_t overflow. */
    if (tbl->rule_sz != 0 && new_sz > SIZE_MAX / tbl->rule_sz)
        return -ENOMEM;
    newlen = tbl->rule_sz * new_sz;

    newmem = realloc(tbl->rules, newlen);
    if (newmem == NULL)
        return -ENOMEM;
    tbl->rules = newmem;
    /* BUG FIX: `sz` is the capacity in *rules*, not bytes — the capacity
     * checks in acl_tbl_ensure_capacity()/this function compare rule
     * counts, so storing `newlen` (bytes) made them over-report space. */
    tbl->sz = new_sz;
    return 0;
}
 
static void
acl_tbl_clear(struct acl_rule_table *tbl)
{
    /* Drop all staged rules but keep the allocation for reuse. */
    tbl->len = 0;
}
 
/*
 * Make sure the rule table can hold at least `sz` rules, growing it in
 * powers of two to amortize realloc cost.
 *
 * Returns 0 on success, -ENOSPC if the requested size cannot be
 * represented, or a negative errno from acl_tbl_resize().
 */
static int
acl_tbl_ensure_capacity(struct acl_rule_table *tbl, const size_t sz)
{
    uint64_t new_sz;

    /* Already enough room. */
    if (sz <= tbl->sz)
        return 0;

    /* BUG FIX: `new_sz` was read uninitialized — the growth computation
     * was missing.  Round the request up to the next power of two;
     * cne_align64pow2() wraps to a value < sz only on 64-bit overflow,
     * which the check below turns into -ENOSPC. */
    new_sz = cne_align64pow2(sz);

    /* Overflow check. */
    if (new_sz < sz)
        return -ENOSPC;

    return acl_tbl_resize(tbl, new_sz);
}
 
/* Encode one parsed rule descriptor into slot `idx` of the rule table. */
static void
acl_tbl_write_rule(struct acl_rule_table *tbl, const size_t idx, const struct acl_rule_desc *r)
{
    struct cne_acl_rule *dst = acl_tbl_get_rule(tbl, idx);

    /* Deny rules carry ACL_DENY_SIGNATURE in userdata; permit rules are
     * offset by 1 so a zero classify result always means "no match". */
    dst->data.userdata = idx + (r->deny ? ACL_DENY_SIGNATURE : 1);
    /* Match in every category. */
    dst->data.category_mask = -1;
    /* Earlier rules get higher priority. */
    dst->data.priority = CNE_ACL_MAX_PRIORITY - idx;

    /* Protocol field is wildcarded (value 0 with an empty mask). */
    dst->field[PROTO_FIELD_IPV4].value.u8      = 0;
    dst->field[PROTO_FIELD_IPV4].mask_range.u8 = 0;

    /* Source prefix. */
    dst->field[SRC_FIELD_IPV4].value.u32     = r->src_addr;
    dst->field[SRC_FIELD_IPV4].mask_range.u8 = r->src_msk;

    /* Destination prefix. */
    dst->field[DST_FIELD_IPV4].value.u32     = r->dst_addr;
    dst->field[DST_FIELD_IPV4].mask_range.u8 = r->dst_msk;
}
 
/* Append one rule at the end of the table; the caller must already have
 * ensured sufficient capacity. */
static void
acl_tbl_add_rule(struct acl_rule_table *tbl, const struct acl_rule_desc *rule)
{
    acl_tbl_write_rule(tbl, tbl->len, rule);
    tbl->len++;
}
 
/*
 * Thread-safe append of one parsed rule to the staging table.
 *
 * Returns 0 on success, a positive errno if the mutex could not be taken,
 * or a negative errno if the table could not grow.
 */
static int
acl_add_rule(const struct acl_rule_desc *rule)
{
    struct acl_rule_table *tbl = &acl_rules;
    size_t newidx              = tbl->len;
    int ret, mret;

    mret = pthread_mutex_lock(&ctx_mutex);
    if (mret != 0) {
        CNE_ERR("Mutex lock failed: %s\n", strerror(mret));
        return mret;
    }

    /* BUG FIX: the new rule is written at index `newidx`, so we need room
     * for newidx + 1 rules.  Requesting only `newidx` was off by one and,
     * on the very first add (len == sz == 0), wrote through a NULL rule
     * buffer. */
    ret = acl_tbl_ensure_capacity(tbl, newidx + 1);
    if (ret < 0)
        goto unlock;

    acl_tbl_add_rule(tbl, rule);

unlock:
    mret = pthread_mutex_unlock(&ctx_mutex);
    if (mret != 0)
        CNE_ERR("Mutex unlock failed: %s\n", strerror(mret));

    return ret;
}
 
/* Empty the staged rule table under the context mutex.  Always returns 0
 * once the lock was taken; returns the lock error otherwise. */
static int
acl_clear(void)
{
    int err = pthread_mutex_lock(&ctx_mutex);

    if (err != 0) {
        CNE_ERR("Mutex lock failed: %s\n", strerror(err));
        return err;
    }

    acl_tbl_clear(&acl_rules);

    err = pthread_mutex_unlock(&ctx_mutex);
    if (err != 0)
        CNE_ERR("Mutex unlock failed: %s\n", strerror(err));

    return 0;
}
 
static inline int
validate_ipv4_pkt(
const struct cne_ipv4_hdr *pkt, uint32_t link_len)
 
{
    
 
    
        return -1;
 
    
    
 
    
        return -1;
    
        return -1;
 
    
        return -1;
 
    return 0;
}
 
/*
 * Filter a burst of received packets down to valid IPv4 packets and fill
 * the classify context with pointers to their ACL match data.
 *
 * NOTE(review): the local declarations and the data_ptrs[] assignment were
 * missing from the corrupted source; they are reconstructed here from the
 * surviving uses (m, ipv4_hdr, acl_classify_ctx->pkts[]) — confirm against
 * the original source.
 *
 * Invalid packets are freed.  Returns the number of packets staged for
 * classification.
 */
static uint16_t
populate_acl_classify(struct acl_classify_t *acl_classify_ctx, pktmbuf_t **pkts, unsigned int len)
{
    unsigned int i;
    uint16_t num = 0;

    for (i = 0; i < len; i++) {
        pktmbuf_t *m = pkts[i];
        const struct cne_ipv4_hdr *ipv4_hdr =
            pktmbuf_mtod_offset(m, const struct cne_ipv4_hdr *, ETH_HDR_LEN);

        /* Drop anything that is not a sane IPv4 packet. */
        if (validate_ipv4_pkt(ipv4_hdr, m->data_len) < 0) {
            pktmbuf_free(m);
            continue;
        }

        acl_classify_ctx->data_ptrs[num] =
            pktmbuf_mtod_offset(m, const uint8_t *, ACL_DATA_OFFSET);
        acl_classify_ctx->pkts[num] = m;
        num++;
    }
    return num;
}
 
/* NOTE(review): the signature line of this function is truncated in this
 * chunk — only the `int` return type survives.  From the body it is the
 * per-lport ACL forwarding loop taking at least a `jcfg_lport_t *lport`
 * and a `struct fwd_info *fwd`; the declarations/assignments of `n_pkts`,
 * `pkt` and `dst` are also missing below.  Confirm against the original
 * source before building. */
int
{
    
    /* In permissive mode, packets that match no rule are still forwarded. */
    const bool fwd_non_matching = fwd->test == ACL_PERMISSIVE_TEST;
    struct fwd_port *pd         = lport->priv_;
    struct acl_classify_t acl_classify_ctx;
    struct acl_fwd_stats *stats = &pd->acl_stats;
    uint16_t n_pkts, n_filtered, n_permit, n_deny;
    int i;

    if (!pd)
        return 0;

    /* NOTE(review): the RX burst call that should set n_pkts is missing —
     * as written n_pkts is read uninitialized below. */
    switch (fwd->pkt_api) {
    case XSKDEV_PKT_API:
        break;
    case PKTDEV_PKT_API:
        if (n_pkts == PKTDEV_ADMIN_STATE_DOWN)
            return 0;
        break;
    default:
        n_pkts = 0;
        break;
    }

    /* Pre-filter: keep only valid IPv4 packets for classification. */
    n_filtered = populate_acl_classify(&acl_classify_ctx, pd->rx_mbufs, n_pkts);

    stats->acl_prefilter_drop += n_pkts - n_filtered;

    /* Nothing valid to classify. */
    if (n_filtered == 0)
        return 0;

    cne_acl_classify(ctx, acl_classify_ctx.data_ptrs, acl_classify_ctx.acl_results, n_filtered,

                     MAX_CATEGORIES);

    n_deny   = 0;
    n_permit = 0;

    for (i = 0; i < n_filtered; i++) {
        const uint32_t res = acl_classify_ctx.acl_results[i];
        /* Deny rules carry ACL_DENY_SIGNATURE in their userdata. */
        const bool deny = (res & ACL_DENY_SIGNATURE) != 0;
        /* Any non-zero result means some rule matched. */
        const bool permit = (res != 0);
        /* Forward when not denied and either matched or in permissive mode. */
        const bool forward = !deny && (fwd_non_matching | permit);

        if (forward) {
            uint8_t dst_lport = get_dst_lport(
pktmbuf_mtod(pkt, 
void *));

            /* NOTE(review): the lookup of `dst` from dst_lport and the TX
             * enqueue appear to be missing here. */
            if (!dst)
                /* Fall back to echoing out of the receiving lport. */
                dst = lport;

            n_permit++;
        } else {
            /* NOTE(review): the pktmbuf_free() of denied packets appears
             * to be missing here — as written they would leak. */
            n_deny++;
        }
    }

    int nb_lports = jcfg_num_lports(fwd->jinfo);
    for (int i = 0; i < nb_lports; i++) {
        /* NOTE(review): the per-lport dst lookup and TX flush body are
         * missing from this loop. */
        if (!dst)
            continue;

    }

    stats->acl_deny += n_deny;
    stats->acl_permit += n_permit;

    return 0;
}
 
/* Populate the staging table with the example's initial rule set and push
 * it into the given ACL context.  Returns 0 on success, negative on error. */
static int
add_init_rules(struct cne_acl_ctx *ctx)
{
#define DST_DENY_RULE_NUM  100  /* destination-prefix deny rules */
#define SRC_DENY1_RULE_NUM 15   /* first batch of source-prefix deny rules */
#define SRC_DENY2_RULE_NUM 15   /* second batch of source-prefix deny rules */
#define ALLOW_RULE_NUM     4096 /* host-to-host permit rules */
    unsigned int i, total_num;
    int ret;

    total_num = DST_DENY_RULE_NUM + SRC_DENY1_RULE_NUM + SRC_DENY2_RULE_NUM + ALLOW_RULE_NUM;

    /* Reserve space for the whole initial rule set up front. */
    ret = acl_tbl_ensure_capacity(&acl_rules, total_num);
    if (ret < 0)
        CNE_ERR_RET(
"Failed to allocate ACL rule table: %s\n", strerror(-ret));


    cne_printf(
"Creating destination IP deny rules:\n");


    for (i = 0; i < DST_DENY_RULE_NUM; i++) {
        struct acl_rule_desc rule = {0};

        /* NOTE(review): the rule.dst_addr assignment appears to be missing
         * here — as written every iteration denies dst 0.0.0.0/24. */
        rule.dst_msk  = 24;
        rule.deny     = true;

        acl_tbl_add_rule(&acl_rules, &rule);
    }


    /* Deny sources 10.0.0.0/8 .. 24.0.0.0/8. */
    for (i = 0; i < SRC_DENY1_RULE_NUM; i++) {
        struct acl_rule_desc rule = {0};

        rule.src_addr = 
CNE_IPV4(10 + i, 0, 0, 0);

        rule.src_msk  = 8;
        rule.deny     = true;

        acl_tbl_add_rule(&acl_rules, &rule);
    }


    for (i = 0; i < SRC_DENY2_RULE_NUM; i++) {
        struct acl_rule_desc rule = {0};

        /* NOTE(review): the rule.src_addr assignment appears to be missing
         * here — as written every iteration denies src 0.0.0.0/24. */
        rule.src_msk  = 24;
        rule.deny     = true;

        acl_tbl_add_rule(&acl_rules, &rule);
    }

    cne_printf(
"Creating permit IP range remapping rules...\n");

    cne_printf(
"   210.0.[0-15].[0-255]/32 -> 110.0.[0-15].[0-255]/32\n");


    /* One /32 permit rule per (octet2, octet1) pair in 210.0.x.y. */
    for (i = 0; i < ALLOW_RULE_NUM; i++) {
        const uint8_t octet1      = (uint8_t)(i & 0xFF);
        const uint8_t octet2      = (uint8_t)((i >> 8) & 0xFF);
        struct acl_rule_desc rule = {0};

        rule.src_addr = 
CNE_IPV4(210, 0, octet2, octet1);

        rule.src_msk  = 32;

        rule.dst_addr = 
CNE_IPV4(110, 0, octet2, octet1);

        rule.dst_msk  = 32;

        acl_tbl_add_rule(&acl_rules, &rule);
    }


    /* Hand the staged rules to the ACL library. */
    ret = 
cne_acl_add_rules(ctx, (
const struct cne_acl_rule *)acl_rules.rules, total_num);

    if (ret < 0)
        CNE_ERR_RET(
"Failed to add rules: %s\n", strerror(-ret));


    return 0;
}
 
/* NOTE(review): this jcfg thread-iterator callback is truncated — the
 * parameter list and the declaration of `thd` (presumably the jcfg thread
 * object passed by jcfg_thread_foreach(), see all_threads_stopped()) are
 * missing.  It appears to return 0 for paused (or exempt) threads and -1
 * for running ones; confirm against the original source. */
static int
{

        return 0;

    return thd->
pause ? 0 : -1;

}
 
/* True when every forwarding thread reports itself paused. */
static bool
all_threads_stopped(struct fwd_info *fwd)
{
    /* thread_paused() returns non-zero for a running worker, so the
     * foreach yields 0 only when every thread is paused. */
    const int rc = jcfg_thread_foreach(fwd->jinfo, thread_paused, NULL);

    return rc == 0;
}
 
/* NOTE(review): the signature of this function is truncated — only the
 * `int` return type survives.  The body suggests it is the UDS "clear ACL
 * rules" command handler; confirm against the original source. */
int
{
    /* Empty the staged rule table; the built context is left untouched. */
    acl_clear();

    return 0;
}
 
/*
 * Parse "<src IP>/<src mask>:<dst IP>/<dst mask>:<allow|deny>" into *rule.
 *
 * `buf` is tokenized in place (caller passes a scratch copy).  Addresses
 * are stored in host byte order.  Returns 0 on success, -1 on any parse
 * error; *rule is only meaningful on success.
 */
static int
parse_acl_rule(char *buf, struct acl_rule_desc *rule)
{
    const char *src_ip_str, *src_mask_str, *dst_ip_str, *dst_mask_str, *rule_str;
    struct in_addr src_addr, dst_addr;
    long src_mask, dst_mask;
    char *state, *endp;
    bool deny;

    src_ip_str   = strtok_r(buf, "/", &state);
    src_mask_str = strtok_r(NULL, ":", &state);
    dst_ip_str   = strtok_r(NULL, "/", &state);
    dst_mask_str = strtok_r(NULL, ":", &state);
    rule_str     = strtok_r(NULL, ":", &state);

    if (src_ip_str == NULL || src_mask_str == NULL || dst_ip_str == NULL || dst_mask_str == NULL ||
        rule_str == NULL)
        return -1;

    /* BUG FIX: inet_aton() returns 0 on failure and non-zero on success —
     * it never returns a negative value, so the previous `< 0` checks
     * silently accepted malformed addresses. */
    if (inet_aton(src_ip_str, &src_addr) == 0)
        return -1;
    if (inet_aton(dst_ip_str, &dst_addr) == 0)
        return -1;

    /* Masks must be whole decimal numbers in [0, 32]; reject trailing junk. */
    src_mask = strtol(src_mask_str, &endp, 10);
    if (endp == src_mask_str || *endp != '\0' || src_mask < 0 || src_mask > 32)
        return -1;

    dst_mask = strtol(dst_mask_str, &endp, 10);
    if (endp == dst_mask_str || *endp != '\0' || dst_mask < 0 || dst_mask > 32)
        return -1;

    if (strcasecmp("allow", rule_str) == 0)
        deny = false;
    else if (strcasecmp("deny", rule_str) == 0)
        deny = true;
    else
        return -1;

    rule->deny     = deny;
    /* s_addr is network order; the rule table stores host order. */
    rule->src_addr = be32toh(src_addr.s_addr);
    rule->src_msk  = (uint8_t)src_mask;
    rule->dst_addr = be32toh(dst_addr.s_addr);
    rule->dst_msk  = (uint8_t)dst_mask;

    return 0;
}
 
/* UDS handler: parse one rule from `params` and stage it in the rule
 * table.  All outcomes (including errors) are reported through
 * uds_append(); the handler itself always returns 0. */
int
fwd_acl_add_rule(uds_client_t *c, const char *cmd __cne_unused, const char *params)
{
    struct acl_rule_desc rule = {0};
    char *scratch;
    int err;

    if (params == NULL)
        goto bad_param;

    /* parse_acl_rule() tokenizes in place, so work on a private copy. */
    scratch = strdup(params);
    if (scratch == NULL) {
        uds_append(c, "\"error\":\"Failed to allocate memory\"");
        return 0;
    }

    err = parse_acl_rule(scratch, &rule);
    free(scratch);
    if (err < 0)
        goto bad_param;

    err = acl_add_rule(&rule);
    if (err < 0)
        uds_append(c, "\"error\":\"Failed to add ACL rule: %s\"", strerror(-err));

    return 0;

bad_param:
    uds_append(c, "\"error\":\"Command expects parameter: "
                  "<src IP>/<src mask>:<dst IP>/<dst mask>:<allow|deny>\"");
    return 0;
}
 
/* NOTE(review): the signature of this function is truncated — only the
 * `int` return type survives.  The body suggests it is the UDS "build ACL
 * context" command handler (it reads `c` as a uds_client_t *). */
int
{
    int ret, mret;
    struct fwd_info *fwd = (struct fwd_info *)(c->info->priv);

    /* Rebuilding the context is only safe while forwarding is stopped. */
    if (!all_threads_stopped(fwd)) {
        uds_append(c, 
"\"error\":\"Not all forwarding threads are stopped\"");

        return 0;
    }

    mret = pthread_mutex_lock(&ctx_mutex);
    if (mret != 0) {
        CNE_ERR("Mutex lock failed: %s\n", strerror(mret));
        return 0;
    }

    /* NOTE(review): this looks garbled — freeing the old context
     * (cne_acl_free) and creating a fresh one (cne_acl_create) appear to
     * be missing, so any existing context is leaked and NULLed, and the
     * "not initialized" branch below always fires afterwards. */
    if (ctx != NULL) {
        ctx = NULL;
    }
    if (ctx == NULL) {
        uds_append(c, 
"\"error\":\"ACL context not initialized\"");

        goto unlock;
    }

    /* Push the staged rule table into the context. */
    ret = 
cne_acl_add_rules(ctx, (
const struct cne_acl_rule *)acl_rules.rules, acl_rules.len);

    if (ret < 0) {
        uds_append(c, 
"\"error\":\"Cannot add ACL rules to context: %s\"", strerror(-ret));

        goto unlock;
    }

    /* NOTE(review): the cne_acl_build() call that should set `ret` here
     * appears to be missing — as written this re-tests the add_rules
     * result, which is already known to be >= 0. */
    if (ret != 0)
        uds_append(c, 
"\"error\":\"Cannot build ACL context: %s\"", strerror(-ret));


unlock:
    mret = pthread_mutex_unlock(&ctx_mutex);
    if (mret != 0)
        CNE_ERR("Mutex unlock failed: %s\n", strerror(mret));
    return 0;
}
 
/* Emit summary information about the staged rule table over UDS. */
static void
print_acl_info(uds_client_t *c)
{
    const size_t nrules = acl_rules.len;

    uds_append(c, "\"num rules\":%zu,", nrules);
    uds_append(c, "\"rule pages\":%zu,", ACL_NUM_PAGES(nrules));
    uds_append(c, "\"rules per page\":%d", ACL_RULES_PER_PAGE);
}
 
/* Emit one staged rule (addresses, masks and allow/deny type) over UDS. */
static void
print_acl_rule(uds_client_t *c, size_t idx)
{
    const struct cne_acl_rule *rule = acl_tbl_get_rule(&acl_rules, idx);
    struct in_addr src_addr, dst_addr;

    /* Field values are stored host-order; inet_ntoa() wants network order. */
    src_addr.s_addr = htobe32(rule->field[SRC_FIELD_IPV4].value.u32);
    dst_addr.s_addr = htobe32(rule->field[DST_FIELD_IPV4].value.u32);

    const uint8_t src_msk = rule->field[SRC_FIELD_IPV4].mask_range.u8;
    const uint8_t dst_msk = rule->field[DST_FIELD_IPV4].mask_range.u8;
    /* Deny rules were stamped with ACL_DENY_SIGNATURE in their userdata. */
    const bool deny = !!(rule->data.userdata & ACL_DENY_SIGNATURE);

    /* inet_ntoa() returns a static buffer, so emit each address before
     * converting the next one. */
    uds_append(c, "\"src addr\":\"%s/%d\",", inet_ntoa(src_addr), src_msk);
    uds_append(c, "\"dst addr\":\"%s/%d\",", inet_ntoa(dst_addr), dst_msk);
    uds_append(c, "\"type\":\"%s\"", deny ? "deny" : "allow");
}
 
/*
 * Emit one page of rules (ACL_RULES_PER_PAGE entries) over UDS, comma
 * separated.  Out-of-range pages emit nothing.
 */
static void
print_acl_rule_page(uds_client_t *c, size_t pg)
{
    size_t start, end, cur;

    if (pg >= ACL_NUM_PAGES(acl_rules.len)) {
        return;
    }
    start = pg * ACL_RULES_PER_PAGE;
    end   = CNE_MIN((pg + 1) * ACL_RULES_PER_PAGE, acl_rules.len);

    for (cur = start; cur < end; cur++) {
        /* BUG FIX: the comma separator was missing, so the `if` guard
         * swallowed print_acl_rule() and the first rule of every page was
         * never printed. */
        if (cur != start)
            uds_append(c, ",");
        print_acl_rule(c, cur);
    }
}
 
/* Parsed argument of the UDS "read" command. */
struct acl_read_param {
    bool is_page; /* true: `num` is a page index, false: a rule index */
    size_t num;   /* page or rule number requested */
};
/*
 * Parse "p:<page num>" or "r:<rule num>" (case-insensitive) into *out.
 *
 * Returns 0 on success, -EINVAL on malformed input, -ENOMEM if the scratch
 * copy could not be allocated.  *out is only written on success.
 */
static int
parse_acl_read_param(const char *params, struct acl_read_param *out)
{
    const char *type, *num;
    char *state, *dup, *endp;
    int ret = -EINVAL;
    int64_t parsed;
    bool is_page;

    /* strtok_r() mutates its input, so work on a private copy. */
    dup = strdup(params);
    if (dup == NULL)
        return -ENOMEM;

    type = strtok_r(dup, ":", &state);
    num  = strtok_r(NULL, ":", &state);
    if (type == NULL || num == NULL)
        goto end;

    if (strcasecmp(type, "p") == 0)
        is_page = true;
    else if (strcasecmp(type, "r") == 0)
        is_page = false;
    else
        goto end;

    /* Reject empty numbers, trailing junk, negatives and out-of-range
     * values (errno/ERANGE was previously unchecked, so e.g. a number
     * larger than INT64_MAX silently clamped). */
    errno  = 0;
    parsed = strtoll(num, &endp, 10);
    if (endp == num || *endp != '\0')
        goto end;
    if (errno == ERANGE || parsed < 0)
        goto end;

    out->num     = (size_t)parsed;
    out->is_page = is_page;

    ret = 0;
end:
    free(dup);
    return ret;
}
 
/* UDS handler: report rule-table info (no params), one page ("p:<n>") or
 * a single rule ("r:<n>").  Errors are reported via uds_append(); the
 * handler always returns 0. */
int
fwd_acl_read(uds_client_t *c, const char *cmd __cne_unused, const char *params)
{
    struct acl_read_param prm;
    int err;

    /* The rule table is shared with the add/build paths — serialize. */
    err = pthread_mutex_lock(&ctx_mutex);
    if (err != 0) {
        CNE_ERR("Mutex lock failed: %s\n", strerror(err));
        return 0;
    }

    if (params == NULL) {
        /* No argument: emit summary information only. */
        print_acl_info(c);
        goto unlock;
    }

    err = parse_acl_read_param(params, &prm);
    switch (err) {
    case 0:
        break;
    case -ENOMEM:
        uds_append(c, "\"error\":\"Cannot allocate memory\"");
        goto unlock;
    case -EINVAL:
        uds_append(c, "\"error\":\"Parameter must be: 'p:<page num>' or 'r:<rule num>'\"");
        goto unlock;
    default:
        goto unlock;
    }

    if (prm.is_page)
        print_acl_rule_page(c, prm.num);
    else if (prm.num >= acl_rules.len)
        uds_append(c, "\"error\":\"Wrong ACL rule number\"");
    else
        print_acl_rule(c, prm.num);

unlock:
    err = pthread_mutex_unlock(&ctx_mutex);
    if (err != 0)
        CNE_ERR("Mutex unlock failed: %s\n", strerror(err));
    return 0;
}
 
/* Create the ACL context and load the initial rule set.  Returns 0 on
 * success, -1 on failure. */
int
acl_init(struct fwd_info *fwd)
{
    int ret = -1, mret;


    mret = pthread_mutex_lock(&ctx_mutex);
    if (mret != 0) {
        CNE_ERR("Mutex lock failed: %s\n", strerror(mret));
        return -1;
    }

    /* NOTE(review): this section looks garbled — the cne_acl_create()
     * call (and its error goto) that should be the body of this `if`
     * appears to be missing, so the next `if` statement becomes its body
     * and add_init_rules() only runs when ctx is NULL. */
    if (ctx == NULL)

    /* NOTE(review): the error-handling body (presumably a CNE_ERR_GOTO to
     * `unlock`) for a failed add_init_rules() appears to be missing. */
    if (add_init_rules(ctx) < 0)


    /* NOTE(review): a cne_acl_build() step may also be missing before the
     * success message below — confirm against the original source. */

    cne_printf(
"ACL rule table created successfully\n");


    if (fwd->test == ACL_STRICT_TEST) {
        cne_printf(
"All traffic will be dropped unless matching a permit rule\n");

    } else {
        cne_printf(
"All traffic will be forwarded unless matching a deny rule\n");

    }
    
    ret = 0;

unlock:
    mret = pthread_mutex_unlock(&ctx_mutex);
    if (mret != 0)
        CNE_ERR("Mutex unlock failed: %s\n", strerror(mret));
    return ret;
}
int cne_acl_add_rules(struct cne_acl_ctx *ctx, const struct cne_acl_rule *rules, uint32_t num)
 
void cne_acl_free(struct cne_acl_ctx *ctx)
 
struct cne_acl_ctx * cne_acl_create(const struct cne_acl_param *param)
 
int cne_acl_build(struct cne_acl_ctx *ctx, const struct cne_acl_config *cfg)
 
int cne_acl_classify(const struct cne_acl_ctx *ctx, const uint8_t **data, uint32_t *results, uint32_t num, uint32_t categories)
 
#define CNE_PTR_ADD(ptr, x)
 
static uint64_t cne_align64pow2(uint64_t v)
 
#define CNE_IPV4(a, b, c, d)
 
#define CNE_ERR_GOTO(lbl,...)
 
CNDP_API int cne_printf(const char *fmt,...)
 
CNDP_API jcfg_lport_t * jcfg_lport_by_index(jcfg_info_t *jinfo, int idx)
 
static uint16_t pktdev_rx_burst(uint16_t lport_id, pktmbuf_t **rx_pkts, const uint16_t nb_pkts)
 
static __cne_always_inline void pktmbuf_free(pktmbuf_t *m)
 
#define pktmbuf_mtod(m, t)
 
#define pktmbuf_mtod_offset(m, t, o)
 
CNDP_API uint16_t txbuff_add(txbuff_t *buffer, pktmbuf_t *tx_pkt)
 
static int txbuff_count(txbuff_t *buffer)
 
CNDP_API uint16_t txbuff_flush(txbuff_t *buffer)
 
CNDP_API int uds_append(uds_client_t *client, const char *format,...)
 
CNDP_API __cne_always_inline uint16_t xskdev_rx_burst(xskdev_info_t *xi, void **bufs, uint16_t nb_pkts)