/*
 Copyright (c) 2003 RIPE

 All Rights Reserved

 Permission to use, copy, modify, and distribute this software and its
 documentation for any purpose and without fee is hereby granted,
 provided that the above copyright notice appear in all copies and that
 both that copyright notice and this permission notice appear in
 supporting documentation, and that the name of the author not be
 used in advertising or publicity pertaining to distribution of the
 software without specific, written prior permission.

 THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING
 ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS; IN NO EVENT SHALL
 AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY
 DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN
 AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

 $Id: server.c,v 1.32 2003/04/16 09:00:48 can Exp $
*/

#include "server.h"

/*
 * Looks up username in the password database (getpwnam) and stores the
 * account's numeric user id in *uid and its primary group id in *gid.
 *
 * A failed lookup is reported through local_error().
 * NOTE(review): the entry is dereferenced right after that report, so this
 * relies on local_error() not returning -- confirm.
 */
void getUser(gchar * username, uid_t * uid, gid_t * gid)
{
    struct passwd *entry = getpwnam(username);

    if (!entry) {
        local_error("getpwnam failed for %s", username);
    }
    *gid = entry->pw_gid;
    *uid = entry->pw_uid;
}

/*
 * Returns TRUE when str is a numeric IPv4 or IPv6 address literal and
 * FALSE otherwise (including NULL and the empty string).
 *
 * AI_NUMERICHOST stops getaddrinfo() from consulting the resolver, so host
 * names are rejected and the check never triggers a network lookup.
 */
gboolean isIP(gchar * str)
{
    gboolean result = FALSE;
    struct addrinfo hints, *aip = NULL;

    if ((str == NULL) || (*str == '\0')) {
        return (FALSE);
    }
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_UNSPEC;
    hints.ai_flags = AI_NUMERICHOST;
    hints.ai_socktype = SOCK_STREAM;
    if (getaddrinfo(str, "0", &hints, &aip) == 0) {
        if (aip != NULL) {
            result = ((aip->ai_family == AF_INET)
                      || (aip->ai_family == AF_INET6));
            freeaddrinfo(aip);
        }
    }
    return result;
}

/*
 * Builds the query that is forwarded upstream by adding this proxy's own
 * -V information to the client's query.
 *
 * query      - the client's query; trailing whitespace is ignored
 * clientAddr - the client connection; its nodeName is embedded in the switch
 * hasComma   - TRUE when the caller found a comma in the client's -V switch
 *
 * newClientAddr below is whatever v62v4() returns for the client's
 * nodeName (presumably an IPv6-to-IPv4 mapping -- confirm in v62v4).
 *
 * Two output shapes are produced:
 *   - hasComma is TRUE and the query starts with "-V": the text after "-V"
 *     is split at the first comma into <left> and the rest; the rest is
 *     split at the first space into <right> and <remainder>. The result is
 *       -V<myName>@<nodeName>@<left>@<newClientAddr>,<right> <remainder>
 *   - otherwise:
 *       -V<myName>@<nodeName>,<newClientAddr> <query>
 *
 * Returns a newly allocated string the caller releases with g_free(), or
 * NULL when query is NULL/empty or the -V switch has nothing after its
 * comma.
 */
gchar *formProxySwitch(gchar * query, sockInfo * clientAddr,
                       gboolean hasComma)
{
    gchar *result = NULL;
    gchar **splitv = NULL, **splitvright = NULL;
    gchar *newClientAddr, *qry;
    gchar *vswitchleft = NULL, *vswitchright = NULL;

    if ((query == NULL) || (*query == '\0')) {
        return (NULL);
    }
    /* do the mapping */
    newClientAddr = v62v4(clientAddr->nodeName);
    /* work on a private copy without trailing whitespace */
    qry = g_strdup(query);
    g_strchomp(qry);
    if (hasComma && ((strncmp(qry, "-V", 2)) == 0)) {
        splitv = g_strsplit(qry + 2, ",", 2);
        /* g_strsplit() returns an empty vector for an empty string, so
         * every element is checked before it is used */
        if ((splitv[0] != NULL) && (splitv[1] != NULL)) {
            splitvright = g_strsplit(splitv[1], " ", 2);
            if (splitvright[0] != NULL) {
                vswitchleft = g_strstrip(g_strdup(splitv[0]));
                vswitchright = g_strstrip(g_strdup(splitvright[0]));
                result =
                    g_strdup_printf("-V%s@%s@%s@%s,%s %s", myName,
                                    clientAddr->nodeName, vswitchleft,
                                    newClientAddr, vswitchright,
                                    /* no text after the switch: forward
                                     * an empty query instead of NULL */
                                    (splitvright[1] != NULL)
                                    ? splitvright[1] : "");
            }
        }
        /* g_free() and g_strfreev() both accept NULL */
        g_free(vswitchleft);
        g_free(vswitchright);
        g_strfreev(splitvright);
        g_strfreev(splitv);
    } else {
        result =
            g_strdup_printf("-V%s@%s,%s %s", myName, clientAddr->nodeName,
                            newClientAddr, qry);
    }
    g_free(newClientAddr);
    g_free(qry);
    return (result);
}

/*
 * Reads text from the peer behind s into str.
 *
 * s         - connection to read from (s->s is polled, s->fRead is read)
 * str       - output buffer; it must hold at least 2 * len bytes, because
 *             the line that crosses the len limit is still stored in full
 * len       - maximum accepted input length
 * timeout   - set to TRUE when poll() reports no data within 120 seconds
 *             or fails
 * oversized - set to TRUE when the collected input reaches len characters
 * readAll   - FALSE: stop after the first line that ends in a newline;
 *             TRUE: keep collecting lines until len is reached or the
 *             stream ends
 *
 * Returns FALSE only in the timeout case. End of stream returns TRUE,
 * possibly with an empty str.
 *
 * NOTE(review): poll() watches the descriptor while fgets() reads through
 * stdio; if s->fRead is buffered, data already held by stdio is invisible
 * to poll(). Confirm how the stream is configured where it is created.
 */
gboolean readData(sockInfo * s, gchar * str, ssize_t len,
                  gboolean * timeout, gboolean * oversized,
                  gboolean readAll)
{
    struct pollfd fds;
    gchar *c;
    gboolean inputFinished = FALSE, result = TRUE;
    gint pollResult;
    size_t used = 0, got;

    fds.fd = s->s;
    fds.events = POLLIN | POLLPRI;
    *oversized = FALSE;
    *timeout = FALSE;
    str[0] = 0;
    if (len < 2) {
        /* fgets() cannot store a single character in such a buffer, so
         * no input can ever fit */
        *oversized = TRUE;
        return (TRUE);
    }
    c = g_new0(gchar, len);     /* fgets() writes at most len bytes */
    while (!inputFinished) {
        if (feof(s->fRead)) {
            inputFinished = TRUE;
            continue;
        }
        pollResult = poll(&fds, 1, 120000);     /* timeout is fixed */
        if (pollResult <= 0) {
            inputFinished = TRUE;
            *timeout = TRUE;
            result = FALSE;
        } else if ((fgets(c, (int) len, s->fRead)) == NULL) {
            inputFinished = TRUE;
        } else {
            got = strlen(c);
            /* used < len and got < len, so used + got + 1 <= 2 * len - 1:
             * the append always fits the caller's buffer */
            memcpy(str + used, c, got + 1);
            used += got;
            if (used >= (size_t) len) {
                inputFinished = TRUE;
                *oversized = TRUE;
            } else if ((!readAll) && (got > 0) && (c[got - 1] == '\n')) {
                /* got can be 0 when the line starts with a NUL byte */
                inputFinished = TRUE;
            }
        }
    }
    g_free(c);
    return (result);
}

/*
 * Writes str to the write stream of f (f->fWrite).
 *
 * If escaped is TRUE, str is first passed through g_strcompress(), which
 * turns escape sequences (a backslash followed by n, t, ...) into the
 * characters they stand for; the temporary copy is released before
 * returning.
 *
 * Returns FALSE when str is NULL or empty or when fputs() reports an
 * error, TRUE otherwise. The stream is not flushed here.
 */
gboolean displayMessage(sockInfo * f, gchar * str, gboolean escaped)
{
    gchar *expanded = NULL;
    gboolean result;

    if ((str == NULL) || (*str == 0)) {
        return (FALSE);
    }
    if (escaped) {
        expanded = g_strcompress(str);
    }
    result = (fputs(escaped ? expanded : str, f->fWrite) != EOF);
    g_free(expanded);           /* g_free() accepts NULL */
    return (result);
}

/*
 * Reports whether the option letter sw occurs in the query string qry.
 *
 * qry is split on single spaces, empty tokens are dropped, and the
 * remaining tokens are handed to mg_getopt() as an argument vector with
 * sw as the only known option.
 *
 * Returns TRUE as soon as mg_getopt() returns sw, FALSE when the tokens
 * are exhausted first or qry is NULL.
 */
gboolean checkSwitch(gchar * qry, gchar sw)
{
    getopt_state_t *gst;
    gboolean found = FALSE;
    gchar optstring[2];
    gchar **tokens, **argv;
    gint argc = 0, i, opt;

    if (qry == NULL) {
        return (FALSE);
    }
    /* the option string must be NUL-terminated; the address of a single
     * character is not a string */
    optstring[0] = sw;
    optstring[1] = '\0';

    tokens = g_strsplit(qry, " ", -1);
    for (i = 0; tokens[i] != NULL; i++) {
        /* count tokens */
    }
    argv = g_new0(gchar *, i + 1);
    for (i = 0; tokens[i] != NULL; i++) {
        if (tokens[i][0] != '\0') {
            argv[argc++] = tokens[i];   /* ownership moves to argv */
        } else {
            g_free(tokens[i]);  /* empty token: not kept, so free it */
        }
    }
    g_free(tokens);             /* container only; the strings are owned
                                 * by argv or already freed */

    gst = mg_new(0);
    /* opt is an int so that the comparison with EOF is reliable even
     * where plain char is unsigned */
    while ((opt = mg_getopt(argc, argv, optstring, gst)) != EOF) {
        if (opt == sw) {
            found = TRUE;
            break;
        }
    }
    g_free(gst);
    g_strfreev(argv);
    return (found);
}

/*
 * Sends qry, followed by CR LF, to sock and relays the reply read from
 * sock to client.
 *
 * qry is modified in place: it is cut at its first carriage return and at
 * its first line feed before being sent.
 *
 * The reply is read in chunks of up to acceptLength characters and relayed
 * until three consecutive line feed characters have been seen.
 *
 * Returns TRUE when the reply was relayed up to that terminator. Returns
 * FALSE when qry is NULL or empty, when writing to sock fails, when
 * reading from sock times out, or when sock reaches end of stream before
 * the terminator; in the last two cases the loop is left instead of
 * retrying forever.
 */
gboolean forwardQuery(sockInfo * client, gchar * qry, sockInfo * sock)
{
    gchar *c, *cp;
    gint linesBlank = 0;
    gboolean timeout, oversized, result = TRUE;

    if (qry == NULL) {
        return (FALSE);
    }
    cp = strchr(qry, '\r');
    if (cp != NULL) {
        *cp = 0;
    }
    cp = strchr(qry, '\n');
    if (cp != NULL) {
        *cp = 0;
    }
    if (!(displayMessage(sock, qry, FALSE))
        || !(displayMessage(sock, "\r\n", FALSE))) {
        return (FALSE);
    }
    g_message("forwarded [%s] %s", client->nodeName, qry);

    c = g_new(char, acceptLength * 2);
    while (linesBlank < 3) {
        if (!(readData
              (sock, c, acceptLength, &timeout, &oversized, TRUE))) {
            /* upstream went silent */
            result = FALSE;
            break;
        }
        if (*c == 0) {
            /* nothing was read: upstream closed the connection before
             * the terminator arrived */
            result = FALSE;
            break;
        }
        displayMessage(client, c, FALSE);
        for (cp = c; *cp != 0; cp++) {
            if (*cp == '\n') {
                linesBlank++;
            } else if (linesBlank < 3) {
                linesBlank = 0;
            }
        }
    }
    g_free(c);
    return (result);
}

/*
 * Thread pool worker that serves one client connection.
 *
 * data      - the sockInfo of the accepted client; it is freed here
 * user_data - unused
 *
 * Flow:
 *   1. connCanConnect() is asked whether the client may connect; if it
 *      returns a message, that message is sent and the connection ends.
 *   2. Queries are read one line at a time. A "-V" switch whose first
 *      token contains a comma is only accepted from clients found by
 *      aclLookup(); over-long input is rejected.
 *   3. On the first accepted query a TCP connection to
 *      FORWARDHOST:FORWARDPORT is opened and reused for later queries.
 *   4. Each query is rewritten by formProxySwitch() and relayed with
 *      forwardQuery(). Unless the first forwarded query carries the 'k'
 *      switch, the connection ends after one query.
 *
 * Always returns NULL.
 */
GFunc server(gpointer data, gpointer user_data)
{
    struct addrinfo hints, *res = NULL;
    struct sockaddr_in sin;
    gchar *hostname, *port;
    gchar *buf, *modifiedQuery, *denyMessage, *bufesc, *buftemp;
    gboolean timeout = FALSE, oversized = FALSE, finished = FALSE,
        persistentModeChecked = FALSE, persistentMode = FALSE,
        hasComma = FALSE;
    sockInfo *newClient, *newServer = NULL;
    gint sock = (-1), error;

    newClient = (sockInfo *) data;
    denyMessage = connCanConnect(connTrack, newClient->nodeName);
    if (denyMessage != NULL) {
        g_warning
            ("client_denied [%s]: too many concurrent connections",
             newClient->nodeName);
        displayMessage(newClient, denyMessage, TRUE);
        connDisconnect(connTrack, newClient->nodeName);
        sockInfo_free(newClient);
        /* NOTE(review): my_thread_end() runs here although
         * my_thread_init() has not been called on this path -- confirm
         * that this is intended. */
        my_thread_end();
        return NULL;
    }
    my_thread_init();
    buf = g_new(char, acceptLength * 2);
    while (!finished) {
        if (!readData
            (newClient, buf, acceptLength, &timeout, &oversized, FALSE)) {
            finished = TRUE;
        } else if (timeout) {
            finished = TRUE;
        } else if (buf[0] == '\0') {
            /* nothing was read: the client closed the connection. End
             * the session instead of treating it as a malformed query. */
            finished = TRUE;
        } else if ((strncmp(buf, "-V", 2)) == 0) {
            /* look for a comma inside the first token after "-V" */
            hasComma = FALSE;
            buftemp = buf + 2;
            while ((*buftemp != 0) && (g_ascii_isspace(*buftemp))) {
                buftemp++;
            }
            while ((*buftemp != 0) && (!(g_ascii_isspace(*buftemp)))) {
                if (*buftemp == ',') {
                    hasComma = TRUE;
                    break;
                }
                buftemp++;
            }
            if ((hasComma)
                && (aclLookup(aclHash, newClient->nodeName) == NULL)) {
                displayMessage(newClient, (gchar *)
                               getConf("ADDRESS_PASSING_FORBIDDEN"), TRUE);
                bufesc = g_strescape(buf, NULL);
                g_message("proxy_denied [%s] %s", newClient->nodeName,
                          bufesc);
                g_free(bufesc);
                finished = TRUE;
            }
        } else if (oversized) {
            displayMessage(newClient,
                           (gchar *) getConf("INPUT_LINE_TOO_LONG"), TRUE);
            g_message("longquery [%s]", newClient->nodeName);
            finished = TRUE;
        }
        bufesc = g_strescape(buf, NULL);
        g_message("incoming_query [%s] %s", newClient->nodeName, bufesc);
        g_free(bufesc);
        if (!finished) {
            if (sock == (-1)) {
                /* create a socket to the server; the settings are read
                 * into locals so that no thread can observe a
                 * half-initialised shared pair */
                hostname = (gchar *) getConf("FORWARDHOST");
                port = (gchar *) getConf("FORWARDPORT");
                sock = socket(AF_INET, SOCK_STREAM, 0);
                if (sock == -1) {
                    local_error("socket: %s", strerror(errno));
                }
                /* getaddrinfo() is used because this code runs in pool
                 * threads and gethostbyname() returns shared static
                 * data; the lookup is limited to IPv4 as before */
                memset(&hints, 0, sizeof(hints));
                hints.ai_family = AF_INET;
                hints.ai_socktype = SOCK_STREAM;
                error = getaddrinfo(hostname, NULL, &hints, &res);
                if ((error != 0) || (res == NULL)) {
                    close(sock);
                    local_error("unknown host: %s", hostname);
                }
                memset(&sin, 0, sizeof(sin));
                sin.sin_family = AF_INET;
                sin.sin_port = htons(atoi(port));
                sin.sin_addr =
                    ((struct sockaddr_in *) res->ai_addr)->sin_addr;
                freeaddrinfo(res);
                if (connect(sock, (struct sockaddr *) &sin, sizeof(sin)) ==
                    -1) {
                    /* report first so that close() cannot disturb errno */
                    g_warning("connect: %s", strerror(errno));
                    close(sock);
                    g_free(buf);
                    connDisconnect(connTrack, newClient->nodeName);
                    sockInfo_free(newClient);
                    my_thread_end();
                    return NULL;
                }
                newServer = sockInfo_new(sock, "::0000");
            }

            /* check if in persistent mode */
            if (!persistentModeChecked) {
                persistentMode = checkSwitch(buf, 'k');
                persistentModeChecked = TRUE;
            }

            /* modify the -V part of the query */
            modifiedQuery = formProxySwitch(buf, newClient, hasComma);
            if (modifiedQuery == NULL) {
                displayMessage(newClient, (gchar *)
                               getConf("INVALID_ADDRESS_PASSING"), TRUE);
                /* client input is escaped before logging, like the other
                 * log lines in this function */
                bufesc = g_strescape(buf, NULL);
                g_message("invalid_addr_pass [%s] %s", newClient->nodeName,
                          bufesc);
                g_free(bufesc);
                finished = TRUE;
            } else if (strlen(modifiedQuery) > acceptLength) {
                displayMessage(newClient,
                               (gchar *) getConf("FORWARD_LINE_TOO_LONG"),
                               TRUE);
                g_message("longforward [%s] %s", newClient->nodeName,
                          modifiedQuery);
                finished = TRUE;
            } else {
                /* Forward query to server and forward result to client */
                finished =
                    !forwardQuery(newClient, modifiedQuery, newServer);

                /* We exit right away if not in persistent mode */
                if (!persistentMode) {
                    finished = TRUE;
                }
            }
            g_free(modifiedQuery);
        }
    }
    g_free(buf);
    connDisconnect(connTrack, newClient->nodeName);
    sockInfo_free(newClient);
    sockInfo_free(newServer);
    my_thread_end();
    return NULL;
}

/*
 * Switches to the user given in configuration
 */
void switchUser(void)
{
    uid_t switchuser;
    gid_t switchgroup;

    getUser((gchar *) getConf("SWITCHUSER"), &switchuser, &switchgroup);
    /* Change to a safer user */
    if ((setgid(switchgroup)) != 0) {
        local_error("setgid: %s", strerror(errno));
    }
    if ((setegid(switchgroup)) != 0) {
        local_error("setegid: %s", strerror(errno));
    }
    if ((setuid(switchuser)) != 0) {
        local_error("setuid: %s", strerror(errno));
    }
    if ((seteuid(switchuser)) != 0) {
        local_error("seteuid: %s", strerror(errno));
    }
}

/*
 * Creates the listening socket.
 *
 * The address comes from getaddrinfo() on the LISTENHOST and LISTENPORT
 * configuration entries (passive, stream); only the first result is used.
 * SO_REUSEADDR is set before bind().
 *
 * Returns the bound (not yet listening) socket. Every failure is reported
 * through local_error(); the socket and the address list are released
 * first, and -1 is returned in case local_error() comes back.
 */
gint myBind()
{
    struct addrinfo hints, *aip = NULL;
    gint error, sock, sock_opt = 1, saved_errno;

    /* Prepare addrinfo struct for getaddrinfo */
    memset(&hints, 0, sizeof(hints));
    hints.ai_flags = AI_PASSIVE;
    hints.ai_socktype = SOCK_STREAM;

    /* This should give us values for listening */
    error =
        getaddrinfo((char *) getConf("LISTENHOST"),
                    (char *) getConf("LISTENPORT"), &hints, &aip);
    if (error != 0) {
        local_error("getaddrinfo: %s", gai_strerror(error));
        return (-1);
    }
    sock = socket(aip->ai_family, aip->ai_socktype, aip->ai_protocol);
    if (sock < 0) {
        /* errno is saved because the cleanup calls may change it */
        saved_errno = errno;
        freeaddrinfo(aip);
        local_error("socket: %s", strerror(saved_errno));
        return (-1);
    }
    if (setsockopt
        (sock, SOL_SOCKET, SO_REUSEADDR, (void *) &sock_opt,
         sizeof(sock_opt)) == -1) {
        saved_errno = errno;
        close(sock);
        freeaddrinfo(aip);
        local_error("setsockopt: %s", strerror(saved_errno));
        return (-1);
    }
    if ((bind(sock, aip->ai_addr, aip->ai_addrlen)) == -1) {
        saved_errno = errno;
        close(sock);
        freeaddrinfo(aip);
        local_error("bind: %s", strerror(saved_errno));
        return (-1);
    }
    freeaddrinfo(aip);
    return (sock);
}

/*
 * Puts sock into listening state and accepts connections forever.
 *
 * The MAXCONN configuration entry is used both as the listen backlog and
 * as the maximum number of threads in the pool that runs server(). The
 * global acceptLength is set from MAXLENACCEPT before the first
 * connection is accepted.
 *
 * For every accepted connection the peer address is turned into its
 * numeric text form, wrapped in a sockInfo and pushed to the pool;
 * server() is responsible for freeing it. This function does not return.
 */
void startServer(gint sock)
{
    GThreadPool *serverPool;
    int error, new_sock, maxConn, saved_errno;
    socklen_t faddrlen;
    struct sockaddr_storage faddr;
    gchar hname[NI_MAXHOST], sname[NI_MAXSERV];
    sockInfo *newClient;

    maxConn = atoi((char *) getConf("MAXCONN"));

    /* Listen on the socket */
    if (listen(sock, maxConn) == -1) {
        /* errno is saved because close() may change it */
        saved_errno = errno;
        close(sock);
        local_error("listen: %s", strerror(saved_errno));
    }
    /* Initialize global read only values */
    acceptLength = atoi((char *) getConf("MAXLENACCEPT"));
    /* Initialize thread pool for incoming connections */
    serverPool =
        g_thread_pool_new((GFunc) & server, NULL, maxConn, FALSE, NULL);
    /* Keep no unused threads */
    g_thread_pool_set_max_unused_threads(0);
    /* this is the main loop */
    for (;;) {
        faddrlen = sizeof(faddr);
        new_sock = accept(sock, (struct sockaddr *) &faddr, &faddrlen);
        if (new_sock == -1) {
            /* interrupted or aborted accepts are retried silently */
            if (errno != EINTR && errno != ECONNABORTED) {
                g_warning("accept: %s", strerror(errno));
            }
            continue;
        }
        error =
            getnameinfo((const struct sockaddr *) &faddr, faddrlen, hname,
                        sizeof(hname), sname, sizeof(sname),
                        NI_NUMERICHOST | NI_NUMERICSERV);
        if (error) {
            g_warning("getnameinfo: %s", gai_strerror(error));
            close(new_sock);
        } else {
            newClient = sockInfo_new(new_sock, hname);
            g_thread_pool_push(serverPool, newClient, NULL);
            g_thread_pool_stop_unused_threads();
        }
    }
    /* end of the main loop */
}
