/*
 * Copyright (c) 1997 Adrian Sun (asun@zoology.washington.edu)
 * All rights reserved. See COPYRIGHT.
 */

#define USE_TCP_NODELAY
#define USE_WRITEV

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <netdb.h>
#include <sys/types.h>
#include <sys/time.h>
#ifdef USE_WRITEV
#include <sys/uio.h>
#endif
#include <sys/socket.h>
#include <netinet/tcp.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <signal.h>
#include <syslog.h>

#ifdef TCPWRAP
#include <tcpd.h>
int allow_severity = LOG_INFO;
int deny_severity = LOG_WARNING;
#endif

#include <atalk/dsi.h>
#include <atalk/compat.h>
#include <netatalk/endian.h>
#include "dsi_private.h"

/* smaller of two values. NOTE: evaluates its arguments more than once,
 * so do not pass expressions with side effects. */
#define min(a,b)  ((a) < (b) ? (a) : (b))

/* backlog passed to listen() for the server socket */
#ifndef DSI_TCPMAXPEND
#define DSI_TCPMAXPEND      5       /* max # of pending connections */
#endif

/* seconds a freshly accepted client gets to send its first request
 * before the child gives up (see the alarm set in dsi_tcp_open) */
#ifndef DSI_TCPTIMEOUT
#define DSI_TCPTIMEOUT      120     /* timeout in seconds for connections */
#endif

/* fallback port when neither an explicit port nor an "afpovertcp"
 * services entry is available */
#define AFPOVERTCP_PORT 548

/* type of the address-length argument handed to accept() */
/* FIXME/SOCKLEN_T: socklen_t is a unix98 feature. */
#ifndef SOCKLEN_T
#define SOCKLEN_T unsigned int
#endif

/* Write `length` bytes from `data` to the client socket, resuming after
 * short writes and retrying when write() fails with EINTR (redundant if
 * every signal handler is installed with SA_RESTART, but harmless).
 * Returns the number of bytes actually written; this is less than
 * `length` only if write() wrote nothing or failed with another error. */
static size_t dsi_tcp_write(DSI *dsi, void *data, size_t length)
{
  const u_int8_t *src = (const u_int8_t *) data;
  size_t done = 0;

  while (done < length) {
    ssize_t n = write(dsi->socket, src + done, length - done);

    if (n > 0) {
      done += (size_t) n;
    } else if (n == -1 && errno == EINTR) {
      continue;                 /* interrupted before writing: retry */
    } else {
      break;                    /* nothing written or hard error */
    }
  }

  return done;
}

/* Read exactly `length` bytes from the client socket into `data`,
 * blocking until they have all arrived; read() is retried on EINTR.
 * Returns the number of bytes stored, which falls short of `length`
 * only on end-of-file or on a read error other than EINTR. */
static size_t dsi_tcp_read(DSI *dsi, void *data, size_t length)
{
  u_int8_t *dst = (u_int8_t *) data;
  size_t got;
  ssize_t n;

  for (got = 0; got < length; ) {
    n = read(dsi->socket, dst + got, length - got);
    if (n == -1 && errno == EINTR)
      continue;                 /* interrupted: try again */
    if (n <= 0)
      break;                    /* EOF or hard error */
    got += (size_t) n;
  }

  return got;
}


/* Send one DSI message: a header serialized from dsi->header, followed
 * by `length` bytes of payload from `buf` (header only when length is 0).
 * The signals in dsi->signals.block are blocked while sending and the
 * previous mask is restored before returning.
 * Returns 1 on success, 0 on failure (EOF or write error).
 * NOTE: assumes dsi_len never causes an overflow of the data buffer. */
static int dsi_tcp_send(DSI *dsi, void *buf, size_t length)
{
  char block[DSI_BLOCKSIZ];
  sigset_t oldset;
  int ok = 1;
#ifdef USE_WRITEV
  struct iovec iov[2];
  size_t  towrite, wrote;
  ssize_t len;
#endif

  block[0] = dsi->header.dsi_flags;
  block[1] = dsi->header.dsi_command;
  memcpy(block + 2, &dsi->header.dsi_requestID, 
	 sizeof(dsi->header.dsi_requestID));
  memcpy(block + 4, &dsi->header.dsi_code, sizeof(dsi->header.dsi_code));
  memcpy(block + 8, &dsi->header.dsi_len, sizeof(dsi->header.dsi_len));
  memcpy(block + 12, &dsi->header.dsi_reserved,
	 sizeof(dsi->header.dsi_reserved));

  /* block signals */
  sigprocmask(SIG_BLOCK, &dsi->signals.block, &oldset);

  if (!length) { /* just write the header */
    ok = (dsi_tcp_write(dsi, block, sizeof(block)) == sizeof(block));
    goto done;
  }
  
#ifdef USE_WRITEV
  iov[0].iov_base = block;
  iov[0].iov_len = sizeof(block);
  iov[1].iov_base = buf;
  iov[1].iov_len = length;
  
  towrite = sizeof(block) + length;
  while (towrite > 0) {
    if ((len = writev(dsi->socket, iov, 2)) == -1 && errno == EINTR)
      continue;

    if (len <= 0) { /* eof or error */
      ok = 0;
      break;
    }

    wrote = (size_t) len;
    if (wrote >= towrite) /* wrote everything out */
      break;
    towrite -= wrote;

    /* Advance the vectors past what writev consumed. Pointer arithmetic
     * goes through char * because arithmetic on void * is not standard C. */
    if (wrote < iov[0].iov_len) { /* still inside the header */
      iov[0].iov_base = (char *) iov[0].iov_base + wrote;
      iov[0].iov_len -= wrote;
    } else { /* header finished; skip into the data */
      wrote -= iov[0].iov_len;
      iov[0].iov_len = 0;
      iov[1].iov_base = (char *) iov[1].iov_base + wrote;
      iov[1].iov_len -= wrote;
    }
  }
  
#else
  /* write the header then data */
  if ((dsi_tcp_write(dsi, block, sizeof(block)) != sizeof(block)) ||
      (dsi_tcp_write(dsi, buf, length) != length))
    ok = 0;
#endif

done:
  sigprocmask(SIG_SETMASK, &oldset, NULL);
  return ok;
}


/* Read one DSI message from the client socket.
 *
 * The header is decoded into dsi->header and dsi->clientID is set from
 * its request ID. Then up to `ilength` bytes of the payload announced by
 * dsi_len are read into the caller's `buf`, and the count is stored in
 * *rlength.
 * Returns the command byte from the header on success, 0 on a short
 * read of either the header or the payload.
 * NOTE: when dsi_len exceeds ilength, the remaining bytes are left
 * unread on the socket. */
static int dsi_tcp_receive(DSI *dsi, void *buf, size_t ilength,
			   size_t *rlength)
{
  char block[DSI_BLOCKSIZ];

  /* read in the header */
  if (dsi_tcp_read(dsi, block, sizeof(block)) != sizeof(block))
    return 0;

  dsi->header.dsi_flags = block[0];
  dsi->header.dsi_command = block[1];
  memcpy(&dsi->header.dsi_requestID, block + 2, 
	 sizeof(dsi->header.dsi_requestID));
  memcpy(&dsi->header.dsi_code, block + 4, sizeof(dsi->header.dsi_code));
  memcpy(&dsi->header.dsi_len, block + 8, sizeof(dsi->header.dsi_len));
  memcpy(&dsi->header.dsi_reserved, block + 12,
	 sizeof(dsi->header.dsi_reserved));
  dsi->clientID = ntohs(dsi->header.dsi_requestID);
  
  /* never store more than the caller's buffer can hold */
  *rlength = min(ntohl(dsi->header.dsi_len), ilength);

  /* the payload goes into the buffer whose size ilength describes */
  if (dsi_tcp_read(dsi, buf, *rlength) != *rlength)
      return 0;

  return block[1];
}

/* Close the client connection if one is open, and mark dsi->socket as
 * closed (-1) so that repeated calls do nothing. */
static void dsi_tcp_close(DSI *dsi)
{
  if (dsi->socket != -1) {
    close(dsi->socket);
    dsi->socket = -1;
  }
}

/* SIGALRM handler installed by dsi_tcp_open() while the forked child
 * waits for the client's first request. It logs the timeout and ends the
 * child with status 1.
 * The handler takes the signal number so it matches the sa_handler type.
 * NOTE(review): syslog() and exit() are not async-signal-safe. That is
 * presumably tolerated here because the process terminates right away. */
static void timeout_handler(int sig)
{
  (void) sig;
  syslog(LOG_ERR, "dsi_tcp_open: connection timed out");
  exit(1);
}

/* Child-side helper: read exactly `length` bytes or log and exit. */
static void dsi_tcp_read_or_exit(DSI *dsi, void *data, size_t length)
{
  if (dsi_tcp_read(dsi, data, length) != length) {
    syslog(LOG_ERR, "dsi_tcp_open: tcp_read: %m");
    exit(1);
  }
}

/* Child-side helper: read and sanity-check the client's first request.
 *
 * A DSI_TCPTIMEOUT-second alarm guards the reads against clients that
 * connect and then stall. The header is validated, then decoded into
 * dsi->header and dsi->clientID. Up to DSI_CMDSIZ bytes of payload go
 * into dsi->commands, with the count in dsi->cmdlen. Any failure is
 * logged and the child exits. On success the alarm is disarmed and the
 * previous SIGALRM disposition is restored. */
static void dsi_tcp_child_handshake(DSI *dsi)
{
  static const struct itimerval timer = {{0, 0}, {DSI_TCPTIMEOUT, 0}};
  static const struct itimerval notimer = {{0, 0}, {0, 0}};
  struct sigaction newact, oldact;
  u_int8_t block[DSI_BLOCKSIZ];

  /* reset a couple signals */
  signal(SIGTERM, SIG_DFL); 
  signal(SIGHUP, SIG_DFL);

  /* Install the handler before arming the timer. sa_mask and sa_flags
   * must be initialized explicitly because newact is an automatic
   * variable. */
  sigemptyset(&newact.sa_mask);
  newact.sa_flags = 0;
  newact.sa_handler = timeout_handler;
  if (sigaction(SIGALRM, &newact, &oldact) < 0) {
    syslog(LOG_ERR, "dsi_tcp_open: sigaction: %m");
    exit(1);
  }
  if (setitimer(ITIMER_REAL, &timer, NULL) < 0) {
    syslog(LOG_ERR, "dsi_tcp_open: setitimer: %m");
    exit(1);
  }

  /* read in commands. this is similar to dsi_receive except for the
   * fact that we do some sanity checking to prevent delinquent
   * connections from causing mischief. */

  /* check flags and command as soon as the first two bytes arrive */
  dsi_tcp_read_or_exit(dsi, block, 2);
  if ((block[0] > DSIFL_MAX) || (block[1] > DSIFUNC_MAX)) {
    syslog(LOG_ERR, "dsi_tcp_open: invalid header");
    exit(1);
  }

  /* read in the rest of the header */
  dsi_tcp_read_or_exit(dsi, block + 2, sizeof(block) - 2);

  dsi->header.dsi_flags = block[0];
  dsi->header.dsi_command = block[1];
  memcpy(&dsi->header.dsi_requestID, block + 2, 
	 sizeof(dsi->header.dsi_requestID));
  memcpy(&dsi->header.dsi_code, block + 4, sizeof(dsi->header.dsi_code));
  memcpy(&dsi->header.dsi_len, block + 8, sizeof(dsi->header.dsi_len));
  memcpy(&dsi->header.dsi_reserved, block + 12,
	 sizeof(dsi->header.dsi_reserved));
  dsi->clientID = ntohs(dsi->header.dsi_requestID);

  /* make sure we don't over-write our buffers. */
  dsi->cmdlen = min(ntohl(dsi->header.dsi_len), DSI_CMDSIZ);
  dsi_tcp_read_or_exit(dsi, dsi->commands, dsi->cmdlen);

  /* Disarm the alarm before restoring the old disposition, so it cannot
   * fire later under a handler that does not expect it. */
  setitimer(ITIMER_REAL, &notimer, NULL);
  sigaction(SIGALRM, &oldact, NULL);
}

/* Accept one pending connection on dsi->serversock and fork. When built
 * with TCPWRAP, the connection must also pass the tcp_wrappers check.
 * The child reads and validates the client's first request before
 * returning (see dsi_tcp_child_handshake).
 * Returns 0 in the child and the child's pid in the parent.
 * Returns -1 if accept fails, the connection is refused, or fork fails.
 * On a refusal or fork failure the accepted socket is closed and
 * dsi->socket is set to -1. On fork failure errno from fork() is kept. */
static int dsi_tcp_open(DSI *dsi)
{
  pid_t pid;
  SOCKLEN_T len;

  len = sizeof(dsi->client);
  dsi->socket = accept(dsi->serversock, (struct sockaddr *) &dsi->client,
		       &len);
  if (dsi->socket < 0)
    return -1;

#ifdef TCPWRAP
  {
    struct request_info req;
    request_init(&req, RQ_DAEMON, dsi->program, RQ_FILE, dsi->socket, NULL);
    fromhost(&req);
    if (!hosts_access(&req)) {
      syslog(deny_severity, "refused connect from %s", eval_client(&req));
      close(dsi->socket);
      dsi->socket = -1;
      return -1;
    }
  }
#endif

  if ((pid = fork()) < 0) {
    /* don't leak the accepted connection; keep fork's errno for callers */
    int saved_errno = errno;

    close(dsi->socket);
    dsi->socket = -1;
    errno = saved_errno;
    return -1;
  }

  if (pid == 0) /* child */
    dsi_tcp_child_handshake(dsi);

  /* send back our pid */
  return pid;
}

/* Create the listening AFP-over-TCP socket in dsi->serversock and point
 * the DSI protocol hooks at the TCP implementations in this file.
 *
 * Port selection:
 *   - ipport if it is non-zero;
 *   - otherwise the "afpovertcp"/tcp services entry;
 *   - otherwise AFPOVERTCP_PORT.
 *
 * Bind address: `address` (dotted-quad, parsed by inet_aton) if given,
 * otherwise INADDR_ANY. In the INADDR_ANY case, dsi->server.sin_addr is
 * replaced after bind/listen with the address `hostname` resolves to,
 * which GetStatus can report.
 *
 * Returns 1 on success. Returns 0 on failure; the socket is closed and
 * dsi->serversock set to -1 if the socket had been created. */
int dsi_tcp_init(DSI *dsi, const char *hostname, const char *address,
		 const u_int16_t ipport)
{
  struct servent     *service;
  struct hostent     *host;
  int                port;
  int                on;

  dsi->protocol = DSI_TCPIP;

  /* create a socket */
  if ((dsi->serversock = socket(PF_INET, SOCK_STREAM, 0)) < 0)
    return 0;
      
  /* find port (kept in network byte order) */
  if (ipport)
    port = htons(ipport);
  else if ((service = getservbyname("afpovertcp", "tcp")) != NULL)
    port = service->s_port;
  else
    port = htons(AFPOVERTCP_PORT);

  /* start from a zeroed address so sin_zero (and sin_len, where the
   * platform has it) hold no garbage */
  memset(&dsi->server, 0, sizeof(dsi->server));

  /* find address */
  if (!address) 
    dsi->server.sin_addr.s_addr = htonl(INADDR_ANY);
  else if (inet_aton(address, &dsi->server.sin_addr) == 0) {
    syslog(LOG_INFO, "dsi_tcp: invalid address (%s)", address);
    close(dsi->serversock);
    dsi->serversock = -1;
    return 0;
  }

  dsi->server.sin_family = AF_INET;
  dsi->server.sin_port = port;

  /* this deals w/ quick close/opens */    
#ifdef SO_REUSEADDR
  on = 1;
  setsockopt(dsi->serversock, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
#endif

#ifdef USE_TCP_NODELAY 
#ifndef SOL_TCP
#define SOL_TCP IPPROTO_TCP
#endif

  on = 1;
  setsockopt(dsi->serversock, SOL_TCP, TCP_NODELAY, &on, sizeof(on));
#endif

  /* now, bind the socket and set it up for listening */
  if ((bind(dsi->serversock, (struct sockaddr *) &dsi->server, 
	    sizeof(dsi->server)) < 0) || 
      (listen(dsi->serversock, DSI_TCPMAXPEND) < 0)) {
    close(dsi->serversock);
    dsi->serversock = -1;
    return 0;
  }

  /* Get a real address for GetStatus. Only an IPv4 result can be copied
   * into sin_addr. NOTE(review): gethostbyname reports failures through
   * h_errno, not errno, so the %m below may be stale. */
  if (!address) {
    if (hostname && (host = gethostbyname(hostname)) != NULL &&
	host->h_addrtype == AF_INET)
      dsi->server.sin_addr.s_addr = ((struct in_addr *) host->h_addr)->s_addr;
    else
      syslog(LOG_INFO, "dsi_tcp (Chooser will not select afp/tcp): %m");
  }

  /* everything's set up. now point protocol specific functions to 
   * tcp versions */
  dsi->proto_open = dsi_tcp_open;
  dsi->proto_close = dsi_tcp_close;
  dsi->proto_send = dsi_tcp_send;
  dsi->proto_receive = dsi_tcp_receive;
  dsi->proto_raw_write = dsi_tcp_write;
  dsi->proto_raw_read = dsi_tcp_read;

  return 1;
}
