diff --git a/servefile b/servefile index 864dfcc..7976a23 100755 --- a/servefile +++ b/servefile @@ -17,6 +17,7 @@ import urllib import os import posixpath import re +import select import SocketServer import socket from stat import ST_SIZE @@ -584,8 +585,8 @@ def catchSSLErrors(BaseSSLClass): class SecureThreadedHTTPServer(ThreadedHTTPServer): - def __init__(self, pubKey, privKey, *args, **kwargs): - ThreadedHTTPServer.__init__(self, *args, **kwargs) + def __init__(self, pubKey, privKey, server_address, RequestHandlerClass, bind_and_activate=True): + ThreadedHTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate) ctx = SSL.Context(SSL.SSLv23_METHOD) if type(pubKey) == crypto.X509 and type(privKey) == crypto.PKey: ctx.use_certificate(pubKey) @@ -597,8 +598,9 @@ class SecureThreadedHTTPServer(ThreadedHTTPServer): self.bsocket = socket.socket(self.address_family, self.socket_type) self.socket = SSL.Connection(ctx, self.bsocket) - self.server_bind() - self.server_activate() + if bind_and_activate: + self.server_bind() + self.server_activate() def shutdown_request(self, request): request.shutdown() @@ -629,11 +631,21 @@ class ServeFile(): self.cert = self.key = None self.auth = None self.maxUploadSize = 0 + self.listenIPv4 = True + self.listenIPv6 = True if self.serveMode not in range(self._NUM_MODES): self.serveMode = None raise ValueError("Unknown serve mode, needs to be MODE_SINGLE, MODE_SINGLETAR, MODE_UPLOAD or MODE_DIRLIST.") + def setIPv4(self, ipv4): + """ En- or disable ipv4 """ + self.listenIPv4 = ipv4 + + def setIPv6(self, ipv6): + """ En- or disable ipv6 """ + self.listenIPv6 = ipv6 + def getIPs(self): """ Get IPs from all interfaces via ip or ifconfig. """ # ip and ifconfig sometimes are located in /sbin/ @@ -660,8 +672,13 @@ class ServeFile(): proc = None if proc: ips = proc.stdout.read().strip().split("\n") - # FIXME: When BaseHTTP supports ipv6 properly, delete this line - ips = filter(lambda ip: ip.find(":") == -1, ips) + + # filter out ips we are not listening on + if not self.listenIPv6: + ips = filter(lambda ip: ip.find(":") == -1, ips) + if not self.listenIPv4: + ips = filter(lambda ip: ip.find(".") == -1, ips) + return ips return None @@ -697,7 +714,7 @@ class ServeFile(): # generate altnames altNames = [] - for ip in self.getIPs() + ["127.0.0.1"]: + for ip in self.getIPs() + ["127.0.0.1", "::1"]: altNames.append("IP:%s" % ip) altNames.append("DNS:localhost") ext = crypto.X509Extension("subjectAltName", False, ",".join(altNames)) @@ -736,19 +753,40 @@ class ServeFile(): raise ServeFileException("User and password both need to be at least one character.") self.auth = base64.b64encode("%s:%s" % (user, password)) - def _createServer(self, handler): + def _createServer(self, handler, withv6=False): + ThreadedHTTPServer.address_family = socket.AF_INET + SecureThreadedHTTPServer.address_family = socket.AF_INET + listenIp = '' server = None + + if withv6: + listenIp = '::' + ThreadedHTTPServer.address_family = socket.AF_INET6 + SecureThreadedHTTPServer.address_family = socket.AF_INET6 + if self.useSSL: if not self._getKey(): self.genKeyPair() - server = SecureThreadedHTTPServer(self._getCert(), self._getKey(), ('', self.port), handler) + server = SecureThreadedHTTPServer(self._getCert(), self._getKey(), (listenIp, self.port), handler, False) else: - server = ThreadedHTTPServer(('', self.port), handler) + server = ThreadedHTTPServer((listenIp, self.port), handler, False) + + if withv6: + server.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) + + server.server_bind() + server.server_activate() + return server def serve(self): self.handler = self._confAndFindHandler() - self.server = self._createServer(self.handler) + self.server = [] + + if self.listenIPv4: + self.server.append(self._createServer(self.handler)) + if self.listenIPv6: + self.server.append(self._createServer(self.handler, withv6=True)) if self.serveMode != self.MODE_UPLOAD: print "Serving \"%s\" at port %d." % (self.target, self.port) @@ -763,13 +801,19 @@ class ServeFile(): print "Could not find any addresses." else: for ip in ips: + if ":" in ip: + ip = "[%s]" % ip print "\thttp%s://%s:%d/" % (self.useSSL and "s" or "", ip, self.port) print "" try: - self.server.serve_forever() + while True: + (servers, _, _) = select.select(self.server, [], []) + for server in servers: + server.handle_request() except KeyboardInterrupt: - self.server.socket.close() + for server in self.server: + server.socket.close() # cleanup potential upload directory if self.dirCreated and len(os.listdir(self.target)) == 0: @@ -888,6 +932,10 @@ def main(): parser.add_argument('-c', '--compression', type=str, metavar='method', \ default="none", \ help="Set compression method, only in combination with --tar. Can be one of %s" % ", ".join(TarFileHandler.compressionMethods)) + parser.add_argument('-4', '--ipv4-only', action="store_true", default=False, \ + help="Listen on IPv4 only") + parser.add_argument('-6', '--ipv6-only', action="store_true", default=False, \ + help="Listen on IPv6 only") args = parser.parse_args() maxUploadSize = 0 @@ -948,6 +996,10 @@ def main(): print "Error: Compression mode '%s' is unknown." % TarFileHandler.compression sys.exit(1) + if args.ipv4_only and args.ipv6_only: + print "You can't listen both on IPv4 and IPv6 \"only\"" + sys.exit(1) + mode = None if args.upload: mode = ServeFile.MODE_UPLOAD @@ -971,6 +1023,10 @@ def main(): server.setAuth(user, password) if compression and compression != "none": server.setCompression(compression) + if args.ipv4_only: + server.setIPv6(False) + if args.ipv6_only: + server.setIPv4(False) server.serve() except ServeFileException, e: print e