# # add_file "contrib/usher.cc" # # patch "ChangeLog" # from [c472ec862675aabab5b00c7a463564a22747966c] # to [a62a6d882a027cce2c3ce56938c68dd5e5d46be4] # # patch "contrib/usher.cc" # from [] # to [7ffeb1adaf524972368b560d911081fab04e9227] # ======================================================================== --- ChangeLog c472ec862675aabab5b00c7a463564a22747966c +++ ChangeLog a62a6d882a027cce2c3ce56938c68dd5e5d46be4 @@ -1,5 +1,13 @@ 2005-09-16 Timothy Brownawell + * contrib/usher.cc: A simple usher/proxy server. It asks connecting + clients for their include pattern, and then forwards the connection + to an appropriate (as given in a config file) monotone server. Note + that all servers operating behind one usher need to have the same + server key. + +2005-09-16 Timothy Brownawell + * netcmd.{cc,hh}, netsync.cc: new netcmd types: usher_cmd and usher_reply_cmd. They are not included in the HMAC, and do not occur during normal communication. Purpose: running multiple servers ======================================================================== --- contrib/usher.cc +++ contrib/usher.cc 7ffeb1adaf524972368b560d911081fab04e9227 @@ -0,0 +1,500 @@ +// Timothy Brownawell +// GPL v2 +// +// This is an "usher" to allow multiple monotone servers to work from +// the same port. It asks the client for the pattern it wants to sync, +// and then looks up the matching server in a table. It then forwards +// the connection to that server. All servers using the same usher need +// to have the same server key. +// +// This requires cooperation from the client, which means it only works +// for recent (post-0.22) clients. 
//
// Usage: usher <listen-port> <config-file>
//
// <listen-port> is the local port to listen on
// <config-file> is a file containing lines of
//   stem ip-address port-number
//
// A request for a pattern starting with "stem" will be directed to the
// server at <ip-address>:<port-number>
//
// NOTE(review): the angle-bracket placeholders, the #include list, the tail
// of get_server() and the opening of the packet-format comment were lost in
// the copy this file was restored from; they are reconstructed.

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <string>
#include <list>
#include <iostream>
#include <fstream>

int listenport = 5253;

char const netsync_version = 5;

char const * const greeting = " Hello! This is the monotone usher at localhost. What would you like?";

char const * const errmsg = "!Sorry, I don't know where to find that.";

// A function instead of the former function-like macro: a macro named `max`
// rewrites every later use of that token (e.g. std::numeric_limits<T>::max())
// and evaluates its arguments twice.
#undef max
static inline int max(int x, int y)
{
  return x > y ? x : y;
}

// Exception used for failed system calls and protocol errors.
//   name: what failed (usually the name of the call)
//   err:  errno when a call returned -1, otherwise the negative return value,
//         or 0 for errors that did not come from a system call
struct errstr
{
  std::string name;
  int err;
  errstr(std::string const & s, int e): name(s), err(e) {}
};

// Return `ret` unchanged when it is non-negative; otherwise throw an errstr
// tagged with `name` (carrying errno when ret is -1).
int tosserr(int ret, std::string const & name)
{
  if (ret == -1)
    throw errstr(name, errno);
  if (ret < 0)
    throw errstr(name, ret);
  return ret;
}

// One config-file line: patterns starting with `stem` go to addr:port.
struct record
{
  std::string stem;
  std::string addr;
  int port;
};

// Routing table in config-file order; the first matching stem wins.
std::list<record> servers;

// Find the first configured server whose stem is a prefix of `reply` (the
// pattern the client asked for).  On success store its address and port in
// `addr` / `port` and return true.  Otherwise log the miss and return false,
// leaving `addr` / `port` untouched.
bool get_server(std::string const & reply, std::string & addr, int & port)
{
  for (std::list<record>::const_iterator i = servers.begin();
       i != servers.end(); ++i) {
    // Prefix test; unlike reply.find(stem) == 0 it never looks past the
    // first stem.size() characters of the reply.
    if (reply.compare(0, i->stem.size(), i->stem) == 0) {
      port = i->port;
      addr = i->addr;
      return true;
    }
  }
  std::cerr<<"no server found for "<<reply<<"\n";
  return false;
}

// Routing packets (built by make_packet, parsed by extract_reply) are:
//   byte     netsync version
//   byte     command (make_packet writes 100)
//   uleb128  number of bytes that follow this field
//   uleb128  length of the string
//   bytes    the string
//
// uleb128 stores a number 7 bits at a time, least significant group first:
//   byte 0x80 | <low 7 bits>
//   ...
+// byte 0xff & +// +// the high bit says that this byte is not the last + +void make_packet(std::string const & msg, char * & pkt, int & size) +{ + size = msg.size(); + char const * txt = msg.c_str(); + char header[6]; + header[0] = netsync_version; + header[1] = 100; + int headersize; + if (size >= 128) { + header[2] = 0x80 | (0x7f & (char)(size+2)); + header[3] = (char)((size+2)>>7); + header[4] = 0x80 | (0x7f & (char)(size)); + header[5] = (char)(size>>7); + headersize = 6; + } else if (size >= 127) { + header[2] = 0x80 | (0x7f & (char)(size+1)); + header[3] = (char)((size+1)>>7); + header[4] = (char)(size); + headersize = 5; + } else { + header[2] = (char)(size+1); + header[3] = (char)(size); + headersize = 4; + } + pkt = new char[headersize + size]; + memcpy(pkt, header, headersize); + memcpy(pkt + headersize, txt, size); + size += headersize; +} + +struct buffer +{ + static int const buf_size = 2048; + static int const buf_reset_size = 1024; + char * ptr; + int readpos; + int writepos; + buffer(): readpos(0), writepos(0) + { + ptr = new char[buf_size]; + } + ~buffer(){delete[] ptr;} + buffer(buffer const & b) + { + ptr = new char[buf_size]; + memcpy(ptr, b.ptr, buf_size); + readpos = b.readpos; + writepos = b.writepos; + } + bool canread(){return writepos > readpos;} + bool canwrite(){return writepos < buf_size;} + void getread(char *& p, int & n) + { + p = ptr + readpos; + n = writepos - readpos; + } + void getwrite(char *& p, int & n) + { + p = ptr + writepos; + n = buf_size - writepos; + } + void fixread(int n) + { + if (n < 0) throw errstr("negative read\n", 0); + readpos += n; + if (readpos == writepos) { + readpos = writepos = 0; + } else if (readpos > buf_reset_size) { + memcpy(ptr, ptr+readpos, writepos-readpos); + writepos -= readpos; + readpos = 0; + } + } + void fixwrite(int n) + { + if (n < 0) throw errstr("negative write\n", 0); + writepos += n; + } +}; + +struct sock +{ + int *s; + operator int(){return s[0];} + sock(int ss) + { + int *x = new 
int[2]; + s = x; + s[0] = ss; + s[1] = 1; + } + sock(sock const & ss){s = ss.s; s[1]++;} + ~sock(){if (s[1]--) return; ::close(s[0]); delete[] s;} + sock operator=(int ss){s[0]=ss;} + void close() + { + if (s[0] == -1) return; + tosserr(shutdown(s[0], SHUT_RDWR), "shutdown()"); + while (::close(s[0]) < 0) { + if (errno != EINTR) throw errstr("close()", 0); + } + s[0]=-1; + } + bool read_to(buffer & buf) + { + char *p; + int n; + buf.getwrite(p, n);new int[2]; + n = read(s[0], p, n);new int[2]; + if (n < 1) { + close(); + return false; + } else + buf.fixwrite(n); + return true; + } + bool write_from(buffer & buf) + { + char *p; + int n; + buf.getread(p, n);new int[2]; + n = write(s[0], p, n);new int[2]; + if (n < 1) { + close(); + return false; + } else + buf.fixread(n); + return true; + } +}; + +sock start(int port) +{ + sock s = tosserr(socket(AF_INET, SOCK_STREAM, 0), "socket()"); + int yes = 1; + tosserr(setsockopt(s, SOL_SOCKET, SO_REUSEADDR, + &yes, sizeof(yes)), "setsockopt"); + sockaddr_in a; + memset (&a, 0, sizeof (a)); + a.sin_port = htons(port); + a.sin_family = AF_INET; + tosserr(bind(s, (sockaddr *) &a, sizeof(a)), "bind"); + listen(s, 10); + return s; +} + +sock make_outgoing(int port, std::string const & address) +{ + sock s = tosserr(socket(AF_INET, SOCK_STREAM, 0), "socket()"); + + struct sockaddr_in a; + memset(&a, 0, sizeof(a)); + a.sin_family = AF_INET; + a.sin_port = htons(port); + + if (!inet_aton(address.c_str(), (in_addr *) &a.sin_addr.s_addr)) + throw errstr("bad ip address format", 0); + + tosserr(connect(s, (sockaddr *) &a, sizeof (a)), "connect()"); + return s; +} + +bool extract_reply(buffer & buf, std::string & out) +{ + char *p; + int n; + buf.getread(p, n); + if (n < 4) return false; + int b = 2; + unsigned int psize = p[b]; + ++b; + if (psize >=128) { + psize = psize - 128 + ((unsigned int)(p[b])<<7); + ++b; + } + if (n < b+psize) return false; + unsigned int size = p[b]; + ++b; + if (size >=128) { + size = size - 128 + ((unsigned 
int)(p[b])<<7); + ++b; + } + if (n < b+size) return false; + out.clear(); + out.append(p+b, size); + buf.fixread(b + size); +} + +struct channel +{ + sock client; + sock server; + bool have_routed; + bool no_server; + buffer cbuf; + buffer sbuf; + channel(sock & c): client(c), server(-1), + have_routed(false), no_server(false) + { + char * dat; + int size; + make_packet(greeting, dat, size); + char *p; + int n; + sbuf.getwrite(p, n); + if (n < size) size = n; + memcpy(p, dat, size); + sbuf.fixwrite(size); + delete[] dat; + + client.write_from(sbuf); + } + void add_to_select(int & maxfd, fd_set & rd, fd_set & wr, fd_set & er) + { + int c = client; + int s = server; + + if (c > 0) { + FD_SET(c, &er); + if (cbuf.canwrite()) + FD_SET(c, &rd); + if (sbuf.canread()) + FD_SET(c, &wr); + maxfd = max(maxfd, c); + } + if (s > 0) { + FD_SET(s, &er); + if (sbuf.canwrite()) + FD_SET(s, &rd); + if (cbuf.canread()) + FD_SET(s, &wr); + maxfd = max(maxfd, s); + } + } + bool process_selected(fd_set & rd, fd_set & wr, fd_set & er) + { + int c = client; + int s = server; +/* NB: read oob data before normal reads */ + if (c > 0 && FD_ISSET(c, &er)) { + char d; + errno = 0; + if (recv(c, &d, 1, MSG_OOB) < 1) + client.close(), c = -1; + else + send(s, &d, 1, MSG_OOB); + } + if (s > 0 && FD_ISSET(s, &er)) { + char d; + errno = 0; + if (recv(s, &d, 1, MSG_OOB) < 1) + server.close(), s = -1; + else + send(c, &d, 1, MSG_OOB); + } + + char *p=0; + int n; + + if (c > 0 && FD_ISSET(c, &rd)) { + new int[2]; + if (!client.read_to(cbuf)) c = -1; + new int[2]; + if (!have_routed) { + std::string reply; + if (extract_reply(cbuf, reply)) { + int port; + std::string addr; + if (get_server(reply, addr, port)) { + try { + server = make_outgoing(port, addr); + have_routed = true; + s = server; + } catch (errstr & e) { + std::cerr<<"cannot contact server "< 0 && FD_ISSET(s, &rd)) {new int[2]; + if (!server.read_to(sbuf)) s = -1; + } + + if (c > 0 && FD_ISSET(c, &wr)) {new int[2]; + if 
(!client.write_from(sbuf)) c = -1; + } + if (s > 0 && FD_ISSET(s, &wr)) {new int[2]; + if (!server.write_from(cbuf)) s = -1; + } + + // close sockets we have nothing more to send to + if (c < 0 && !cbuf.canread()) { + server.close(), s = -1; + } + if ((no_server || have_routed && s < 0) && !sbuf.canread()) { + client.close(), c = -1; + } + } +}; + +int main (int argc, char **argv) +{ + if (argc != 3) { + fprintf (stderr, "Usage\n\tusher \n"); + exit (1); + } + + record rec; + std::ifstream cf(argv[2]); + int pos = 0; + while(cf) { + if (pos == 0) + cf>>rec.stem; + else if (pos == 1) + cf>>rec.addr; + else if (pos == 2) + cf>>rec.port; + else if (pos == 3) { + pos = 0; + servers.push_back(rec); + } + ++pos; + } + + signal (SIGPIPE, SIG_IGN); + + sock h(-1); + try { + h = start(atoi(argv[1])); + } catch (errstr & s) { + std::cerr< channels; + + for (;;) { + fd_set rd, wr, er; + FD_ZERO (&rd); + FD_ZERO (&wr); + FD_ZERO (&er); + FD_SET (h, &rd); + int nfds = h; + channel *newchan = 0; + + for (std::list::iterator i = channels.begin(); + i != channels.end(); ++i) + i->add_to_select(nfds, rd, wr, er); + + int r = select(nfds+1, &rd, &wr, &er, NULL); + + if (r == -1 && errno == EINTR) + continue; + if (r < 0) { + perror ("select()"); + exit (1); + } + if (FD_ISSET(h, &rd)) { + try { + struct sockaddr_in client_address; + unsigned int l = sizeof(client_address); + memset(&client_address, 0, l); + sock cli = tosserr(accept(h, (struct sockaddr *) + &client_address, &l), "accept()"); +// std::cerr<<"connect from "<::iterator i = channels.begin(); + i != channels.end(); ++i) + i->process_selected(rd, wr, er); + if (newchan) { + channels.push_back(*newchan); + delete newchan; + newchan = 0; + } + } + return 0; +}