diff --git a/servefile b/servefile index e36124e..31553f9 100755 --- a/servefile +++ b/servefile @@ -266,8 +266,12 @@ class SecureThreadedHTTPServer(ThreadedHTTPServer): def __init__(self, pubKey, privKey, *args, **kwargs): ThreadedHTTPServer.__init__(self, *args, **kwargs) ctx = SSL.Context(SSL.SSLv23_METHOD) - ctx.use_certificate_file(pubKey) - ctx.use_privatekey_file(privKey) + if type(pubKey) == crypto.X509 and type(privKey) == crypto.PKey: + ctx.use_certificate(pubKey) + ctx.use_privatekey(privKey) + else: + ctx.use_certificate_file(pubKey) + ctx.use_privatekey_file(privKey) self.bsocket = socket.socket(self.address_family, self.socket_type) self.socket = SSL.Connection(ctx, self.bsocket) @@ -298,7 +302,7 @@ class ServeFile(): self.serveMode = serveMode self.dirCreated = False self.useSSL = useSSL - self.certPath = self.keyPath = None + self.cert = self.key = None if self.serveMode not in range(3): self.serveMode = None @@ -334,19 +338,58 @@ class ServeFile(): return ips return None - def setupSSLKeys(self, cert, key): - self.certPath = cert - self.keyPath = key + def setSSLKeys(self, cert, key): + """ Set SSL cert/key. Can be either path to file or pyssl X509/PKey object. """ + self.cert = cert + self.key = key + + def genKeyPair(self): + pkey = crypto.PKey() + pkey.generate_key(crypto.TYPE_RSA, 2048) + + req = crypto.X509Req() + subj = req.get_subject() + subj.CN = "127.0.0.1" + subj.O = "servefile laboratories" + subj.OU = "servefile" + + # generate altnames + altNames = [] + for ip in self.getIPs() + ["127.0.0.1"]: + altNames.append("IP:%s" % ip) + altNames.append("DNS:localhost") + ext = crypto.X509Extension("subjectAltName", False, ",".join(altNames)) + req.add_extensions([ext]) + + req.set_pubkey(pkey) + req.sign(pkey, "sha1") + + cert = crypto.X509() + # some browsers complain if they see a cert from the same authority + # with the same serial ==> we just use the seconds as serial. + cert.set_serial_number(int(time.time())) + cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notAfter(365*24*60*60) + cert.set_issuer(req.get_subject()) + cert.set_subject(req.get_subject()) + cert.add_extensions([ext]) + cert.set_pubkey(req.get_pubkey()) + cert.sign(pkey, "sha1") + + self.cert = cert + self.key = pkey def _getCert(self): - return self.certPath + return self.cert def _getKey(self): - return self.keyPath + return self.key def _createServer(self, handler): server = None if self.useSSL: + if not self._getKey(): + self.genKeyPair() server = SecureThreadedHTTPServer(self._getCert(), self._getKey(), ('', self.port), handler) else: server = ThreadedHTTPServer(('', self.port), handler) @@ -458,7 +501,7 @@ def main(): server = ServeFile(args.target, args.port, mode, args.ssl) if args.ssl and args.key: cert = args.cert or args.key - server.setupSSLKeys(cert, args.key) + server.setSSLKeys(cert, args.key) server.serve() except ServeFileException, e: print e