Compare commits

...

2 Commits

2 changed files with 45 additions and 21 deletions

View File

@ -13,8 +13,8 @@ import argparse
import base64
import cgi
import datetime
import io
import mimetypes
import urllib
import os
import re
import select
@ -27,10 +27,12 @@ import time
try:
import BaseHTTPServer
import SocketServer
from urllib import quote, unquote
except ImportError:
# both have different names in python3
import http.server as BaseHTTPServer
import socketserver as SocketServer
from urllib.parse import quote, unquote
# only activate SSL if available
HAVE_SSL = False
@ -56,7 +58,7 @@ class FileBaseHandler(BaseHTTPServer.BaseHTTPRequestHandler):
Returns True if a redirect was issued. """
if not fileName:
fileName = self.fileName
if urllib.unquote(self.path) != "/" + fileName:
if unquote(self.path) != "/" + fileName:
self.send_response(302)
self.send_header('Location', '/' + fileName)
self.end_headers()
@ -346,12 +348,12 @@ class DirListingHandler(FileBaseHandler):
<p>The requestet URL %s was not found on this server</p>
<p><a href="/">Back to /</a>
</body>
</html>""" % self.escapeHTML(urllib.unquote(self.path))
</html>""" % self.escapeHTML(unquote(self.path))
self.send_header("Content-Length", str(len(errorMsg)))
self.send_header('Connection', 'close')
self.end_headers()
if not head:
self.wfile.write(errorMsg)
self.wfile.write(errorMsg.encode())
def escapeHTML(self, htmlstr):
entities = [("<", "&lt;"), (">", "&gt;")]
@ -378,7 +380,7 @@ class DirListingHandler(FileBaseHandler):
<td class="size">%s</td>
<td class="type">%s</td>
</tr>
""" % (urllib.quote(item), item, lastModified, fileSize, fileType))
""" % (quote(item), item, lastModified, fileSize, fileType))
def sendDirectoryListing(self, path, head):
""" Generate a directorylisting for path and send it """
@ -414,7 +416,7 @@ class DirListingHandler(FileBaseHandler):
</tr>
</thead>
<tbody>
""" % {'path': os.path.normpath(urllib.unquote(self.path))}
""" % {'path': os.path.normpath(unquote(self.path))}
footer = """</tbody></table></div>
<div class="footer"><a href="http://seba-geek.de/stuff/servefile/">servefile %(version)s</a></div>
<script>
@ -540,7 +542,7 @@ class DirListingHandler(FileBaseHandler):
self.send_header("Content-Length", str(len(listing)))
self.send_header('Connection', 'close')
self.end_headers()
self.wfile.write(listing)
self.wfile.write(listing.encode())
def convertSize(self, size):
for ext in "KMGT":
@ -552,7 +554,7 @@ class DirListingHandler(FileBaseHandler):
return (size, ext.strip())
def getCleanPath(self):
urlPath = os.path.normpath(urllib.unquote(self.path)).strip("/")
urlPath = os.path.normpath(unquote(self.path)).strip("/")
path = os.path.join(self.targetDir, urlPath)
return path
@ -594,7 +596,8 @@ class FilePutter(BaseHTTPServer.BaseHTTPRequestHandler):
length = self.getContentLength()
if length < 0:
return
ctype = self.headers.getheader('Content-Type')
print(self.headers)
ctype = self.headers.get('Content-Type')
# check for multipart/form-data.
if not (ctype and ctype.lower().startswith("multipart/form-data")):
@ -615,7 +618,7 @@ class FilePutter(BaseHTTPServer.BaseHTTPRequestHandler):
return
# write file down to disk, send a 200 afterwards
target = open(destFileName, "w")
target = open(destFileName, "wb")
bytesLeft = length
while bytesLeft > 0:
bytesToRead = min(self.blockSize, bytesLeft)
@ -638,7 +641,7 @@ class FilePutter(BaseHTTPServer.BaseHTTPRequestHandler):
if length < 0:
return
fileName = urllib.unquote(self.path)
fileName = unquote(self.path)
if fileName == "/":
# if no filename was given we have to generate one
fileName = str(time.time())
@ -685,7 +688,7 @@ class FilePutter(BaseHTTPServer.BaseHTTPRequestHandler):
self.send_header('Content-Length', str(len(msg)))
self.send_header('Connection', 'close')
self.end_headers()
self.wfile.write(msg)
self.wfile.write(msg.encode())
def getTargetName(self, fname):
""" Generate a clean and secure filename.
@ -769,6 +772,17 @@ class SecureThreadedHTTPServer(ThreadedHTTPServer):
class SecureHandler():
def setup(self):
self.connection = self.request
if sys.version_info[0] > 2:
# python3 SocketIO (replacement for socket._fileobject)
raw_read_sock = socket.SocketIO(self.request, 'rb')
raw_write_sock = socket.SocketIO(self.request, 'wb')
rbufsize = self.rbufsize > 0 and self.rbufsize or io.DEFAULT_BUFFER_SIZE
wbufsize = self.wbufsize > 0 and self.wbufsize or io.DEFAULT_BUFFER_SIZE
self.rfile = io.BufferedReader(raw_read_sock, rbufsize)
self.wfile = io.BufferedWriter(raw_write_sock, wbufsize)
else:
# python2 does not have SocketIO
self.rfile = socket._fileobject(self.request, "rb", self.rbufsize)
self.wfile = socket._fileobject(self.request, "wb", self.wbufsize)
@ -831,7 +845,7 @@ class ServeFile():
# we couldn't find any ip address
proc = None
if proc:
ips = proc.stdout.read().strip().split("\n")
ips = proc.stdout.read().decode().strip().split("\n")
# filter out ips we are not listening on
if not self.listenIPv6:
@ -877,7 +891,7 @@ class ServeFile():
for ip in self.getIPs() + ["127.0.0.1", "::1"]:
altNames.append("IP:%s" % ip)
altNames.append("DNS:localhost")
ext = crypto.X509Extension("subjectAltName", False, ",".join(altNames))
ext = crypto.X509Extension(b"subjectAltName", False, (",".join(altNames)).encode())
req.add_extensions([ext])
req.set_pubkey(pkey)
@ -901,8 +915,8 @@ class ServeFile():
self.key = pkey
print("done.")
print("SHA1 fingerprint:", cert.digest("sha1"))
print("MD5 fingerprint:", cert.digest("md5"))
print("SHA1 fingerprint:", cert.digest("sha1").decode())
print("MD5 fingerprint:", cert.digest("md5").decode())
def _getCert(self):
return self.cert
@ -913,7 +927,7 @@ class ServeFile():
def setAuth(self, user, password, realm=None):
if not user or not password:
raise ServeFileException("User and password both need to be at least one character.")
self.auth = base64.b64encode("%s:%s" % (user, password))
self.auth = base64.b64encode(("%s:%s" % (user, password)).encode()).decode()
self.authrealm = realm
def _createServer(self, handler, withv6=False):
@ -975,7 +989,7 @@ class ServeFile():
else:
pwPart = ""
if self.auth:
pwPart = base64.b64decode(self.auth) + "@"
pwPart = base64.b64decode(self.auth).decode() + "@"
for ip in ips:
if ":" in ip:
ip = "[%s]" % ip
@ -1091,7 +1105,7 @@ class AuthenticationHandler():
errorMsg = "<html><head><title>401 - Unauthorized</title></head><body><h1>401 - Unauthorized</h1></body></html>"
self.send_header("Content-Length", str(len(errorMsg)))
self.end_headers()
self.wfile.write(errorMsg)
self.wfile.write(errorMsg.encode())
def main():

View File

@ -274,3 +274,13 @@ def test_https(run_servefile, datadir):
# assert fingerprint
urllib3.disable_warnings()
check_download(data, protocol='https', verify=False)
def test_https_big_download(run_servefile, datadir):
# test with about 10 mb of data
data = "x" * (10 * 1024 ** 2)
p = datadir({'testfile': data}) / 'testfile'
run_servefile(['--ssl', str(p)])
time.sleep(0.2) # time for generating ssl certificates
urllib3.disable_warnings()
check_download(data, protocol='https', verify=False)