Compare commits

..

2 Commits

2 changed files with 45 additions and 21 deletions

View File

@ -13,8 +13,8 @@ import argparse
import base64 import base64
import cgi import cgi
import datetime import datetime
import io
import mimetypes import mimetypes
import urllib
import os import os
import re import re
import select import select
@ -27,10 +27,12 @@ import time
try: try:
import BaseHTTPServer import BaseHTTPServer
import SocketServer import SocketServer
from urllib import quote, unquote
except ImportError: except ImportError:
# both have different names in python3 # both have different names in python3
import http.server as BaseHTTPServer import http.server as BaseHTTPServer
import socketserver as SocketServer import socketserver as SocketServer
from urllib.parse import quote, unquote
# only activate SSL if available # only activate SSL if available
HAVE_SSL = False HAVE_SSL = False
@ -56,7 +58,7 @@ class FileBaseHandler(BaseHTTPServer.BaseHTTPRequestHandler):
Returns True if a redirect was issued. """ Returns True if a redirect was issued. """
if not fileName: if not fileName:
fileName = self.fileName fileName = self.fileName
if urllib.unquote(self.path) != "/" + fileName: if unquote(self.path) != "/" + fileName:
self.send_response(302) self.send_response(302)
self.send_header('Location', '/' + fileName) self.send_header('Location', '/' + fileName)
self.end_headers() self.end_headers()
@ -346,12 +348,12 @@ class DirListingHandler(FileBaseHandler):
<p>The requestet URL %s was not found on this server</p> <p>The requestet URL %s was not found on this server</p>
<p><a href="/">Back to /</a> <p><a href="/">Back to /</a>
</body> </body>
</html>""" % self.escapeHTML(urllib.unquote(self.path)) </html>""" % self.escapeHTML(unquote(self.path))
self.send_header("Content-Length", str(len(errorMsg))) self.send_header("Content-Length", str(len(errorMsg)))
self.send_header('Connection', 'close') self.send_header('Connection', 'close')
self.end_headers() self.end_headers()
if not head: if not head:
self.wfile.write(errorMsg) self.wfile.write(errorMsg.encode())
def escapeHTML(self, htmlstr): def escapeHTML(self, htmlstr):
entities = [("<", "&lt;"), (">", "&gt;")] entities = [("<", "&lt;"), (">", "&gt;")]
@ -378,7 +380,7 @@ class DirListingHandler(FileBaseHandler):
<td class="size">%s</td> <td class="size">%s</td>
<td class="type">%s</td> <td class="type">%s</td>
</tr> </tr>
""" % (urllib.quote(item), item, lastModified, fileSize, fileType)) """ % (quote(item), item, lastModified, fileSize, fileType))
def sendDirectoryListing(self, path, head): def sendDirectoryListing(self, path, head):
""" Generate a directorylisting for path and send it """ """ Generate a directorylisting for path and send it """
@ -414,7 +416,7 @@ class DirListingHandler(FileBaseHandler):
</tr> </tr>
</thead> </thead>
<tbody> <tbody>
""" % {'path': os.path.normpath(urllib.unquote(self.path))} """ % {'path': os.path.normpath(unquote(self.path))}
footer = """</tbody></table></div> footer = """</tbody></table></div>
<div class="footer"><a href="http://seba-geek.de/stuff/servefile/">servefile %(version)s</a></div> <div class="footer"><a href="http://seba-geek.de/stuff/servefile/">servefile %(version)s</a></div>
<script> <script>
@ -540,7 +542,7 @@ class DirListingHandler(FileBaseHandler):
self.send_header("Content-Length", str(len(listing))) self.send_header("Content-Length", str(len(listing)))
self.send_header('Connection', 'close') self.send_header('Connection', 'close')
self.end_headers() self.end_headers()
self.wfile.write(listing) self.wfile.write(listing.encode())
def convertSize(self, size): def convertSize(self, size):
for ext in "KMGT": for ext in "KMGT":
@ -552,7 +554,7 @@ class DirListingHandler(FileBaseHandler):
return (size, ext.strip()) return (size, ext.strip())
def getCleanPath(self): def getCleanPath(self):
urlPath = os.path.normpath(urllib.unquote(self.path)).strip("/") urlPath = os.path.normpath(unquote(self.path)).strip("/")
path = os.path.join(self.targetDir, urlPath) path = os.path.join(self.targetDir, urlPath)
return path return path
@ -594,7 +596,8 @@ class FilePutter(BaseHTTPServer.BaseHTTPRequestHandler):
length = self.getContentLength() length = self.getContentLength()
if length < 0: if length < 0:
return return
ctype = self.headers.getheader('Content-Type') print(self.headers)
ctype = self.headers.get('Content-Type')
# check for multipart/form-data. # check for multipart/form-data.
if not (ctype and ctype.lower().startswith("multipart/form-data")): if not (ctype and ctype.lower().startswith("multipart/form-data")):
@ -615,7 +618,7 @@ class FilePutter(BaseHTTPServer.BaseHTTPRequestHandler):
return return
# write file down to disk, send a 200 afterwards # write file down to disk, send a 200 afterwards
target = open(destFileName, "w") target = open(destFileName, "wb")
bytesLeft = length bytesLeft = length
while bytesLeft > 0: while bytesLeft > 0:
bytesToRead = min(self.blockSize, bytesLeft) bytesToRead = min(self.blockSize, bytesLeft)
@ -638,7 +641,7 @@ class FilePutter(BaseHTTPServer.BaseHTTPRequestHandler):
if length < 0: if length < 0:
return return
fileName = urllib.unquote(self.path) fileName = unquote(self.path)
if fileName == "/": if fileName == "/":
# if no filename was given we have to generate one # if no filename was given we have to generate one
fileName = str(time.time()) fileName = str(time.time())
@ -685,7 +688,7 @@ class FilePutter(BaseHTTPServer.BaseHTTPRequestHandler):
self.send_header('Content-Length', str(len(msg))) self.send_header('Content-Length', str(len(msg)))
self.send_header('Connection', 'close') self.send_header('Connection', 'close')
self.end_headers() self.end_headers()
self.wfile.write(msg) self.wfile.write(msg.encode())
def getTargetName(self, fname): def getTargetName(self, fname):
""" Generate a clean and secure filename. """ Generate a clean and secure filename.
@ -769,8 +772,19 @@ class SecureThreadedHTTPServer(ThreadedHTTPServer):
class SecureHandler(): class SecureHandler():
def setup(self): def setup(self):
self.connection = self.request self.connection = self.request
self.rfile = socket._fileobject(self.request, "rb", self.rbufsize)
self.wfile = socket._fileobject(self.request, "wb", self.wbufsize) if sys.version_info[0] > 2:
# python3 SocketIO (replacement for socket._fileobject)
raw_read_sock = socket.SocketIO(self.request, 'rb')
raw_write_sock = socket.SocketIO(self.request, 'wb')
rbufsize = self.rbufsize > 0 and self.rbufsize or io.DEFAULT_BUFFER_SIZE
wbufsize = self.wbufsize > 0 and self.wbufsize or io.DEFAULT_BUFFER_SIZE
self.rfile = io.BufferedReader(raw_read_sock, rbufsize)
self.wfile = io.BufferedWriter(raw_write_sock, wbufsize)
else:
# python2 does not have SocketIO
self.rfile = socket._fileobject(self.request, "rb", self.rbufsize)
self.wfile = socket._fileobject(self.request, "wb", self.wbufsize)
class ServeFileException(Exception): class ServeFileException(Exception):
pass pass
@ -831,7 +845,7 @@ class ServeFile():
# we couldn't find any ip address # we couldn't find any ip address
proc = None proc = None
if proc: if proc:
ips = proc.stdout.read().strip().split("\n") ips = proc.stdout.read().decode().strip().split("\n")
# filter out ips we are not listening on # filter out ips we are not listening on
if not self.listenIPv6: if not self.listenIPv6:
@ -877,7 +891,7 @@ class ServeFile():
for ip in self.getIPs() + ["127.0.0.1", "::1"]: for ip in self.getIPs() + ["127.0.0.1", "::1"]:
altNames.append("IP:%s" % ip) altNames.append("IP:%s" % ip)
altNames.append("DNS:localhost") altNames.append("DNS:localhost")
ext = crypto.X509Extension("subjectAltName", False, ",".join(altNames)) ext = crypto.X509Extension(b"subjectAltName", False, (",".join(altNames)).encode())
req.add_extensions([ext]) req.add_extensions([ext])
req.set_pubkey(pkey) req.set_pubkey(pkey)
@ -901,8 +915,8 @@ class ServeFile():
self.key = pkey self.key = pkey
print("done.") print("done.")
print("SHA1 fingerprint:", cert.digest("sha1")) print("SHA1 fingerprint:", cert.digest("sha1").decode())
print("MD5 fingerprint:", cert.digest("md5")) print("MD5 fingerprint:", cert.digest("md5").decode())
def _getCert(self): def _getCert(self):
return self.cert return self.cert
@ -913,7 +927,7 @@ class ServeFile():
def setAuth(self, user, password, realm=None): def setAuth(self, user, password, realm=None):
if not user or not password: if not user or not password:
raise ServeFileException("User and password both need to be at least one character.") raise ServeFileException("User and password both need to be at least one character.")
self.auth = base64.b64encode("%s:%s" % (user, password)) self.auth = base64.b64encode(("%s:%s" % (user, password)).encode()).decode()
self.authrealm = realm self.authrealm = realm
def _createServer(self, handler, withv6=False): def _createServer(self, handler, withv6=False):
@ -975,7 +989,7 @@ class ServeFile():
else: else:
pwPart = "" pwPart = ""
if self.auth: if self.auth:
pwPart = base64.b64decode(self.auth) + "@" pwPart = base64.b64decode(self.auth).decode() + "@"
for ip in ips: for ip in ips:
if ":" in ip: if ":" in ip:
ip = "[%s]" % ip ip = "[%s]" % ip
@ -1091,7 +1105,7 @@ class AuthenticationHandler():
errorMsg = "<html><head><title>401 - Unauthorized</title></head><body><h1>401 - Unauthorized</h1></body></html>" errorMsg = "<html><head><title>401 - Unauthorized</title></head><body><h1>401 - Unauthorized</h1></body></html>"
self.send_header("Content-Length", str(len(errorMsg))) self.send_header("Content-Length", str(len(errorMsg)))
self.end_headers() self.end_headers()
self.wfile.write(errorMsg) self.wfile.write(errorMsg.encode())
def main(): def main():

View File

@ -274,3 +274,13 @@ def test_https(run_servefile, datadir):
# assert fingerprint # assert fingerprint
urllib3.disable_warnings() urllib3.disable_warnings()
check_download(data, protocol='https', verify=False) check_download(data, protocol='https', verify=False)
def test_https_big_download(run_servefile, datadir):
# test with about 10 mb of data
data = "x" * (10 * 1024 ** 2)
p = datadir({'testfile': data}) / 'testfile'
run_servefile(['--ssl', str(p)])
time.sleep(0.2) # time for generating ssl certificates
urllib3.disable_warnings()
check_download(data, protocol='https', verify=False)