/* $Id: sfssd.C,v 1.33 2001/01/13 19:46:13 dm Exp $ */

/*
 *
 * Copyright (C) 1999 David Mazieres (dm@uun.org)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#include "sfssd.h"
#include "parseopt.h"
#include "rxx.h"
#include <dirent.h>

/* Path of the configuration file: set by -f, else the sfssd_config default. */
str configfile;
/* Directory of server revocation certificates, looked up by armored hostid. */
str revocationdir;
/* Every distinct server program, keyed by its argv; one entry may be
 * shared by several Service lines with identical command lines. */
ihash<const vec<str>, sfssrv, &sfssrv::argv, &sfssrv::link> srvtab;
/* All Server sections read from the configuration file. */
list<server, &server::link> serverlist;

/* Hook passed to axprt_unix_aspawnv by sfssrv::launch (presumably run in
 * the child before exec'ing the server program).  If a uid or gid is
 * configured, clear the supplementary group list, then set the gid, then
 * the uid.  The gid must be changed first: once the uid is no longer
 * root, the process can no longer change its gid. */
void
sfssrv::setprivs ()
{
  if ((uid || gid) && setgroups (0, NULL))
    fatal ("could not void grouplist: %m\n");
  if (gid && setgid (*gid))
    warn ("could not setgid (%d): %m\n", *gid);
  // This must be setuid; calling setgid here never changed the user id.
  if (uid && setuid (*uid))
    warn ("could not setuid (%d): %m\n", *uid);
}

void
sfssrv::getpkt (const char *, ssize_t, const sockaddr *)
{
  /* We shouldn't receive anything, so assume this is an EOF */
  warn << "EOF from " << argv[0] << "\n";
  x = NULL;
}

/* Spawn the program described by argv, using setprivs as the spawn hook,
 * and watch the resulting transport for EOF.  If spawning yields no
 * transport, x stays NULL and server::clone answers SFS_TEMPERR for this
 * program's services. */
void
sfssrv::launch ()
{
  x = axprt_unix_aspawnv (argv[0], argv, 0, wrap (this, &sfssrv::setprivs));
  if (!x)
    return;
  x->setrcb (wrap (this, &sfssrv::getpkt));
}

/* Create a server-program entry for command line AV and register it in
 * the global srvtab, which is keyed by argv. */
sfssrv::sfssrv (const vec<str> &av)
  : argv (av)
{
  srvtab.insert (this);
}

/* Unregister this entry from the global srvtab. */
sfssrv::~sfssrv ()
{
  srvtab.remove (this);
}

bool
extension::covered (const bhash<str> &eh)
{
  for (const str *ep = names.base (); ep < names.lim (); ep++)
    if (!eh[*ep])
      return false;
  return true;
}

bool
extension::covered (const vec<str> &ev)
{
  bhash<str> eh;
  for (const str *ep = ev.base (); ep < ev.lim (); ep++)
    eh.insert (*ep);
  return covered (eh);
}

/* Create the release entry for release number R.  Every release starts
 * with one default (freshly constructed) extension entry; getext later
 * inserts others ahead of it, so the default stays at the list's tail. */
release::release (u_int32_t r)
  : rel (r)
{
  extension *defext = New extension;
  extlist.insert_head (defext);
}

/* Delete every extension entry on this release's extlist. */
release::~release ()
{
  extension *cur = extlist.first;
  while (cur) {
    extension *nxt = extlist.next (cur);
    delete cur;
    cur = nxt;
  }
}

/* bhash traversal callback: append name E to the vector *EVP. */
static void
pushext (vec<str> *evp, const str &e)
{
  vec<str> &dst = *evp;
  dst.push_back (e);
}

/* Return this release's extension entry whose name set equals the set of
 * names in EV (duplicates in EV are ignored).  If there is none, create
 * one, insert it at the head of extlist, and return it. */
extension *
release::getext (const vec<str> &ev)
{
  bhash<str> eh;
  for (const str *sp = ev.base (); sp < ev.lim (); sp++)
    eh.insert (*sp);
  // Equal sizes plus "every name of e is in eh" means the sets are equal,
  // assuming e->names holds no duplicates (entries created below are
  // filled from a bhash).  Test against eh directly; the vec overload
  // would rebuild this same hash for every candidate.
  for (extension *e = extlist.first; e; e = extlist.next (e))
    if (eh.size () == e->names.size () && e->covered (eh))
      return e;
  extension *e = New extension;
  eh.traverse (wrap (pushext, &e->names));
  extlist.insert_head (e);
  return e;
}

/* Create a Server section for host name H.  HID, if non-NULL, is the
 * section's hostid and is copied into a newly allocated hostid; NULL
 * leaves hostid unset.  The server is added to the head of serverlist. */
server::server (const str &h, sfs_hash *hid)
  : host (h)
{
  if (hid) {
    hostid.alloc ();
    *hostid = *hid;
  }
  serverlist.insert_head (this);
}

/* Unlink this server from serverlist, then call reltab.deleteall () to
 * dispose of its release entries.  Callers must not also remove the
 * server from serverlist themselves. */
server::~server ()
{
  serverlist.remove (this);
  reltab.deleteall ();
}

/* Return the release entry with the smallest rel that is >= R, or NULL
 * if every entry in reltab has rel < R.  server::clone walks upward from
 * the result with reltab.next.
 *
 * This is a standard binary-search-tree lower bound.  The earlier version
 * seeded the result with the root, so whenever root->rel < R it returned
 * the root (an entry older than R) instead of the correct one. */
release *
server::getrel (u_int32_t r)
{
  release *ret = NULL;
  for (release *rp = reltab.root (); rp; ) {
    if (r <= rp->rel) {
      // rp qualifies; any better (smaller) candidate is in its left subtree.
      ret = rp;
      rp = reltab.left (rp);
    }
    else
      rp = reltab.right (rp);
  }
  return ret;
}

/* Try to hand connection X, whose SFSPROC_CONNECT call is SBP, to one of
 * this server's programs.  Release entries are searched from getrel's
 * result upward.  Within each release, the first extension entry that is
 * covered by the client's extension set EH and defines the requested
 * service wins.  Returns true if SBP was dealt with: either the
 * connection was cloned to the running program, or SFS_TEMPERR was sent
 * because that program has no transport.  Returns false if nothing
 * matched. */
bool
server::clone (ref<axprt_clone> x, svccb *sbp, const bhash<str> &eh)
{
  sfs_connectarg *arg = sbp->template getarg<sfs_connectarg> ();

  for (release *rp = getrel (arg->release); rp; rp = reltab.next (rp)) {
    for (extension *ep = rp->extlist.first; ep; ep = rp->extlist.next (ep)) {
      if (!ep->covered (eh))
	continue;
      sfssrv **slot = ep->srvtab[arg->service];
      if (!slot)
	continue;
      sfssrv *target = *slot;
      if (!target->x) {
	sbp->replyref (sfs_connectres (SFS_TEMPERR));
	return true;
      }
      sbp->ignore ();
      target->x->clone (x);
      return true;
    }
  }
  return false;
}

/* Matches "N" or "N-M" where N and M are decimal integers.
 * NOTE(review): not referenced anywhere in this file (Release lines are
 * matched by relrx inside parseconfig); possibly dead -- confirm before
 * removing. */
static rxx versrx ("^(\\d+)(-(\\d+))?$");
/* Parse configfile, populating serverlist (and each server's releases and
 * extension entries) and srvtab.  Recognized directives:
 *
 *   Server {hostname|*}[:hostid]   start a Server section
 *   Release { N.NN | * }           start a release entry within a Server
 *   Extensions name ...            choose the extension entry for Services
 *   Service num prog [arg ...]     bind a service number to a program
 *   HashCost nbits                 set sfs_hashcost and SFS_HASHCOST
 *   RevocationDir directory        set revocationdir
 *
 * Each error is reported as file:line.  After the whole file has been
 * read, any error is fatal.
 */
static void
parseconfig ()
{
  str cf = configfile;
  parseargs pa (cf);
  bool errors = false;

  str hostname;
  rpc_ptr<sfs_hash> hostid;
  server *s = NULL;		// current Server section
  release *r = NULL;		// current Release entry of s
  extension *e = NULL;		// entry that Service lines attach to

  int line;
  vec<str> av;
  while (pa.getline (&av, &line)) {
    if (!strcasecmp (av[0], "Server")) {
      // Even if this line is bad, later directives must not attach to
      // the previous Server section.
      s = NULL;
      r = NULL;
      e = NULL;
      if (av.size () != 2) {
	warn << cf << ":" << line
	     << ": usage: Server {hostname|*}[:hostid]\n";
	errors = true;
	continue;
      }
      if (strchr (av[1], ':')) {
	hostid.alloc ();
	if (!sfs_parsepath (av[1], &hostname, hostid)) {
	  warn << cf << ":" << line << ": bad hostname/hostid\n";
	  errors = true;
	  continue;
	}
      }
      else {
	hostid.clear ();
	if (av[1] == "*")
	  hostname = myname ();
	else
	  hostname = av[1];
      }

      // Reuse an existing section with the same host name and hostid.
      for (s = serverlist.first; s; s = serverlist.next (s))
	if (hostname == s->host
	    && ((hostid && s->hostid && *hostid == *s->hostid)
		|| (!hostid && !s->hostid)))
	  break;
      if (!s)
	s = New server (hostname, hostid);
    }
    else if (!strcasecmp (av[0], "Release")) {
      static rxx relrx ("^(\\d+)\\.(\\d\\d?)$");
      // A bad Release line must not leave later Extensions/Service lines
      // attached to the previous release.
      r = NULL;
      e = NULL;
      if (av.size () != 2 || (!relrx.search (av[1]) && av[1] != "*")) {
	warn << cf << ":" << line << ": usage Release { N.NN | * }\n";
	errors = true;
	continue;
      }
      if (!s) {
	warn << cf << ":" << line << ": Release must follow Server\n";
	errors = true;
	continue;
      }
      u_int32_t rel;
      if (av[1] == "*")
	rel = 0xffffffff;
      else
	rel = strtoi64 (relrx[1]) * 100 + strtoi64 (relrx[2]);
      r = s->reltab[rel];
      if (!r)
	s->reltab.insert ((r = New release (rel)));
      // Select the last entry of extlist: the default entry created by
      // release::release (getext always inserts at the head).
      for (e = r->extlist.first; r->extlist.next (e); e = r->extlist.next (e))
	;
    }
    else if (!strcasecmp (av[0], "Extensions")) {
      // Without this check, an Extensions line before any valid Release
      // dereferenced a NULL r.
      if (!r) {
	warn << cf << ":" << line << ": Extensions must follow Release\n";
	errors = true;
	e = NULL;
	continue;
      }
      av.pop_front ();
      e = r->getext (av);
    }
    else if (!strcasecmp (av[0], "Service")) {
      u_int32_t snum;
      if (av.size () < 3 || !convertint (av[1], &snum)) {
	warn << cf << ":" << line << ": usage: Service num prog [arg ...]\n";
	errors = true;
	continue;
      }
      if (!e) {
	warn << cf << ":" << line
	     << ": Service must follow Release or Extensions\n";
	errors = true;
	continue;
      }
      if (e->srvtab[snum]) {
	warn << cf << ":" << line
	     << ": Service " << snum << " already defined\n";
	errors = true;
	continue;
      }
      av.pop_front ();
      av.pop_front ();
      av[0] = fix_exec_path (av[0]);
      // Identical command lines share one sfssrv (and one process).
      sfssrv *ss = srvtab[av];
      if (!ss)
	ss = New sfssrv (av);
      e->srvtab.insert (snum, ss);
    }
    else if (!strcasecmp (av[0], "HashCost")) {
      if (av.size () != 2 || !convertint (av[1], &sfs_hashcost)) {
	warn << cf << ":" << line << ": usage: HashCost <nbits>\n";
	errors = true;
      }
      else {
	if (sfs_hashcost > sfs_maxhashcost)
	  sfs_hashcost = sfs_maxhashcost;
	// putenv() makes the string itself part of the environment; it does
	// not copy it.  The buffer must therefore outlive this scope.  Keep
	// it in a static (str is reference counted, so the copy shares the
	// buffer).  Replace the previous one only after putenv has swapped
	// in the new string, e.g. on re-reads after SIGHUP.
	static str hashcost_env;
	str nenv (strbuf ("SFS_HASHCOST=%d", sfs_hashcost));
	putenv (const_cast<char *> (nenv.cstr ())); // XXX - cast for linux putenv
	hashcost_env = nenv;
      }
    }
    else if (!strcasecmp (av[0], "RevocationDir")) {
      if (av.size () != 2) {
	warn << cf << ":" << line << ": usage: RevocationDir <directory>\n";
	errors = true;
      }
      else {
	revocationdir = av[1];
      }
    }
    else {
      errors = true;
      warn << cf << ":" << line << ": unknown directive '"
	   << av[0] << "'\n";
    }
  }

  if (errors)
    fatal ("parse errors in configuration file\n");
}

/* Dispatcher for the first RPC on a freshly accepted connection.  Only
 * SFSPROC_CONNECT is accepted.  If a revocation certificate for the
 * requested hostid is on file, it is returned with SFS_REDIRECT.
 * Otherwise the connection goes to the first server whose clone
 * succeeds, trying in order:
 *   1. the first section with that host name and hostid,
 *   2. the first section with that host name and no hostid,
 *   3. any section with that host name,
 *   4. any section at all.
 * If none takes it, the reply is SFS_NOSUCHHOST. */
static void
clone (ref<asrv> s, ref<axprt_clone> x, svccb *sbp)
{
  // Only the first call is handled here; stop further dispatch to us.
  s->setcb (NULL);
  if (!sbp)
    return;
  if (sbp->proc () != SFSPROC_CONNECT) {
    sbp->reject (PROC_UNAVAIL);
    return;
  }

  sfs_connectarg *arg = sbp->template getarg<sfs_connectarg> ();

  // Extension names offered by the client.
  bhash<str> offered;
  const sfs_extension *const xend = arg->extensions.lim ();
  for (const sfs_extension *xp = arg->extensions.base (); xp < xend; xp++)
    offered.insert (*xp);

  // A revocation certificate on file for this hostid takes precedence.
  sfs_pathrevoke cert;
  str certtext = file2str (revocationdir << "/" << 
			   armor32 (&arg->hostid, sizeof (arg->hostid)));
  if (certtext && str2xdr (cert, certtext)) {
    sfs_connectres redirect (SFS_REDIRECT);
    redirect->revoke->msg = cert.msg;
    redirect->revoke->sig = cert.sig;
    sbp->reply (&redirect);
    return;
  }

  server *srv;

  // 1. Exact host name and hostid: only the first such section is tried.
  for (srv = serverlist.first; srv; srv = serverlist.next (srv)) {
    if (srv->host == arg->name && srv->hostid && *srv->hostid == arg->hostid) {
      if (srv->clone (x, sbp, offered))
	return;
      break;
    }
  }

  // 2. Host name with no hostid: only the first such section is tried.
  for (srv = serverlist.first; srv; srv = serverlist.next (srv)) {
    if (srv->host == arg->name && !srv->hostid) {
      if (srv->clone (x, sbp, offered))
	return;
      break;
    }
  }

  // 3. Any section with a matching host name.
  for (srv = serverlist.first; srv; srv = serverlist.next (srv)) {
    if (srv->host == arg->name && srv->clone (x, sbp, offered))
      return;
  }

  // 4. Any section at all.
  for (srv = serverlist.first; srv; srv = serverlist.next (srv)) {
    if (srv->clone (x, sbp, offered))
      return;
  }

  sbp->replyref (sfs_connectres (SFS_NOSUCHHOST));
}

static void
newserv (int fd)
{
  sockaddr_in sin;
  bzero (&sin, sizeof (sin));
  socklen_t sinlen = sizeof (sin);
  int nfd = accept (fd, (sockaddr *) &sin, &sinlen);
  if (nfd >= 0) {
    warn ("accepted connection from %s\n", inet_ntoa (sin.sin_addr));
    close_on_exec (nfd);
    ref<axprt_clone> x = axprt_clone::alloc (nfd);
    ref<asrv> s = asrv::alloc (x, sfs_program_1);
    s->setcb (wrap (clone, s, x));
  }
  else if (errno != EAGAIN)
    warn ("accept: %m\n");
}

/* Bind the SFS TCP port, register newserv to accept connections on it,
 * and spawn every configured server program.  Failure to set up the
 * listening socket is fatal. */
static void
launchservers ()
{
  int fd = inetsocket (SOCK_STREAM, sfs_port);
  if (fd < 0)
    fatal ("could not bind TCP port %d: %m\n", sfs_port);
  close_on_exec (fd);
  // Without this check, a failed listen() left the daemon running but
  // unable to accept any connection.
  if (listen (fd, 5) < 0)
    fatal ("could not listen on TCP port %d: %m\n", sfs_port);
  fdcb (fd, selread, wrap (newserv, fd));
  srvtab.traverse (&sfssrv::launch);
}

/* SIGHUP handler: delete every server section and server program, re-read
 * the configuration file, and launch the new set of server programs. */
static void
restart ()
{
  warn ("version %s, pid %d, restarted with SIGHUP\n", VERSION, getpid ());
  // server::~server unlinks the server from serverlist itself, so delete
  // the head until the list is empty.  Calling serverlist.remove first,
  // as this code used to, unlinked each server twice.
  while (server *s = serverlist.first)
    delete s;
  srvtab.deleteall ();
  parseconfig ();
  srvtab.traverse (&sfssrv::launch);
}

/* Print the command-line synopsis and exit with status 1. */
static void
usage ()
{
  warnx << progname << ": [-d] [-f configfile]\n";
  exit (1);
}

/* Entry point.  Options:
 *   -d             do not daemonize
 *   -f configfile  use CONFIGFILE instead of the default sfssd_config
 * Reads the configuration and defaults revocationdir to <sfsdir>srvrevoke.
 * Daemonizes unless -d was given or builddir is set, then opens the
 * listening socket, spawns the server programs, installs the SIGHUP
 * restart handler, and enters the event loop. */
int
main (int argc, char **argv)
{
  bool opt_nodaemon = false;
  setprogname (argv[0]);

  int ch;
  // NOTE(review): 'x' is accepted by getopt but has no case below, so -x
  // just prints usage.
  while ((ch = getopt (argc, argv, "dxf:")) != -1)
    switch (ch) {
    case 'd':
      opt_nodaemon = true;
      break;
    case 'f':
      if (configfile)
	usage ();
      configfile = optarg;
      break;
    case '?':
    default:
      usage ();
    }
  argc -= optind;
  argv += optind;
  // sfssd takes no operands; the old "argc > 1" test silently ignored one.
  if (argc > 0)
    usage ();

  sfsconst_init ();
  if (!configfile)
    configfile = sfsconst_etcfile_required ("sfssd_config");


  parseconfig ();
  if (!revocationdir)
    revocationdir = sfsdir << "srvrevoke";
  if (!opt_nodaemon && !builddir)
    daemonize ();
  warn ("version %s, pid %d\n", VERSION, getpid ());
  launchservers ();
  sigcb (SIGHUP, wrap (restart));

  amain ();
}
