GmCapsule [gsorg-style]

Worker threads; configuring CGI

=> 528c1f08cf847bb1f708fca851e0a480d9fcf18c

diff --git a/examplerc b/examplerc
new file mode 100644
index 0000000..9180a2f
--- /dev/null
+++ b/examplerc
@@ -0,0 +1,7 @@
+[cgi.upload]
+protocol = titan
+path = /cat
+command = cat
+
+[cgi.test]
+command = printenv
diff --git a/gemini.py b/gemini.py
index 7297a3b..fec004e 100644
--- a/gemini.py
+++ b/gemini.py
@@ -3,7 +3,9 @@
 
 import fnmatch
 import hashlib
+import queue
 import socket
+import threading
 import time
 from urllib.parse import urlparse
 
@@ -12,6 +14,7 @@ from OpenSSL import SSL, crypto
 
 
 def report_error(stream, code, msg):
+    print(time.strftime('%Y-%m-%d %H:%M:%S'), f'   ', '--', code, msg)
     stream.sendall(f'{code} {msg}\r\n'.encode('utf-8'))
 
 
@@ -54,6 +57,8 @@ def verify_callback(connection, cert, err_num, err_depth, ret_code):
 
 
 class Cache:
+    """Response cache."""
+
     def __init__(self):
         pass
 
@@ -65,112 +70,35 @@ class Cache:
         return
 
 
-class Server:
-    def __init__(self, hostname_or_hostnames, cert_path, key_path,
-                 address='localhost', port=1965,
-                 cache=None, session_id=None, max_upload_size=0):
-        self.hostnames = [hostname_or_hostnames] \
-            if type(hostname_or_hostnames) == str else hostname_or_hostnames
-        self.address = address
-        self.port = port
-        self.entrypoints = {'gemini': {}, 'titan': {}}
-        for proto in ['gemini', 'titan']:
-            self.entrypoints[proto] = {}
-            for hostname in self.hostnames:
-                self.entrypoints[proto][hostname] = []
-        self.cache = cache
-        self.max_upload_size = max_upload_size
-
-        self.context = SSL.Context(SSL.TLS_SERVER_METHOD)
-        self.context.use_certificate_file(str(cert_path))
-        self.context.use_privatekey_file(str(key_path))
-        self.context.set_verify(SSL.VERIFY_PEER, verify_callback)
-        if session_id:
-            if type(session_id) != bytes:
-                raise Exception("session_id type must be `bytes`")
-            self.context.set_session_id(session_id)
-
-        attempts = 60
-        print(f'Opening port {port}...')
-        while True:
-            try:
-                self.sock = socket.socket()
-                self.sock.bind((address, port))
-                self.sock.listen(5)
-                self.sv_conn = SSL.Connection(self.context, self.sock)
-                self.sv_conn.set_accept_state()
-                break
-            except:
-                attempts -= 1
-                if attempts == 0:
-                    raise Exception(f'Failed to open port {port} for listening')
-                time.sleep(2.0)
-                print('...')
-        print(f'Server started on port {port}')
-
-    def add_entrypoint(self, protocol, hostname, path_pattern, entrypoint):
-        self.entrypoints[protocol][hostname].append((path_pattern, entrypoint))
-
-    def __setitem__(self, key, value):
-        #if key.endswith('*'):
-        #    self.wild_entrypoints[key[:-1]] = value
-        #else:
-        #    self.entrypoints[key] = value
-        for hostname in self.hostnames:
-            self.add_entrypoint('gemini', hostname, key, value)
+class Worker(threading.Thread):
+    """Thread that processes incoming requests from clients."""
 
-    # def __getitem__(self, key):
-    #     if key.endswith('*'):
-    #         return self.wild_entrypoints[key[:-1]]
-    #     return self.entrypoints[key]
+    def __init__(self, id, server):
+        super().__init__()
+        self.id = id
+        self.server = server
+        self.jobs = server.work_queue
 
     def run(self):
+        print(f'Worker {self.id} started')
         while True:
+            job = self.jobs.get()
+            if job is None:
+                break
+            stream, from_addr = job
             try:
-                stream = None
-                try:
-                    stream, from_addr = self.sv_conn.accept()
-                    #print(stream, from_addr)
-                    self.process_request(stream, from_addr)
-                except Exception as ex:
-                    import traceback
-                    traceback.print_exc()
-                    print(ex)
-                    if stream:
-                        report_error(stream, 42, str(ex))
-                finally:
-                    if stream:
-                        stream.shutdown()
-                        #print('Goodbye', from_addr)
-            except Exception as ex:
-                print(ex)
+                self.process_request(stream, from_addr)
+            except Exception as error:
+                report_error(stream, 42, str(error))
+            finally:
+                stream.shutdown()
 
-    def find_entrypoint(self, protocol, hostname, path):
-        try:
-            for entry in self.entrypoints[protocol][hostname]:
-                if len(entry[0]) == 0 or fnmatch.fnmatch(path, entry[0]):
-                    return entry[1]
-        except:
-            return None
+        print(f'Worker {self.id} stopped')
 
-        # # Check the more specific virtual host entrypoint first.
-        # virt_path = f":{hostname}:{path}"
-        # if virt_path in self.entrypoints:
-        #     return self.entrypoints[virt_path]
-        # for entry in self.wild_entrypoints:
-        #     if virt_path.startswith(entry):
-        #         return self.wild_entrypoints[entry]
-
-        # if path in self.entrypoints:
-        #     return self.entrypoints[path]
-        # for entry in self.wild_entrypoints:
-        #     if path.startswith(entry):
-        #         return self.wild_entrypoints[entry]
-
-        return None
+    def log(self, *args):
+        print(time.strftime('%Y-%m-%d %H:%M:%S'), f'[{self.id}]', '--', *args)
 
     def process_request(self, stream, from_addr):
-        print(time.strftime('%Y-%m-%d %H:%M:%S'))
         data = bytes()
         MAX_LEN = 1024
         request = None
@@ -179,8 +107,6 @@ class Server:
         req_mime = None
         incoming = stream.recv(MAX_LEN)
 
-        print(dir(stream))
-
         while len(data) < MAX_LEN:
             data += incoming
             crlf_pos = data.find(b'\r\n')
@@ -199,7 +125,14 @@ class Server:
             report_error(stream, 59, "Unsupported protocol")
             return
 
+        cl_cert = stream.get_peer_certificate()
+        identity = Identity(cl_cert) if cl_cert else None
+
         if request.startswith('titan:'):
+            if identity is None and self.server.require_upload_identity:
+                report_error(stream, 60, "Client certificate required for upload")
+                return
+
             # Read the rest of the data.
             parms = request.split(';')
             request = parms[0]
@@ -210,7 +143,8 @@ class Server:
                     req_token = p[6:]
                 elif p.startswith('mime='):
                     req_mime = p[5:]
-            if expected_size > self.max_upload_size and self.max_upload_size > 0:
+            self.log(f'Receiving Titan content: {expected_size}')
+            if expected_size > self.server.max_upload_size and self.server.max_upload_size > 0:
                 report_error(stream, 59, "Maximum content length exceeded")
                 return
             while len(data) < expected_size:
@@ -227,17 +161,25 @@ class Server:
                 report_error(stream, 59, "Gemini disallows request content")
                 return
 
+        self.log(request)
+
         url = urlparse(request)
-        cl_cert = stream.get_peer_certificate()
-        identity = Identity(cl_cert) if cl_cert else None
         path = url.path
         if path == '':
             path = '/'
-        # TODO: get TLS SNI
         hostname = url.hostname
-        entrypoint = self.find_entrypoint(url.scheme, hostname, path)
-        print(entrypoint)
-        cache = None if identity or len(url.query) > 0 else self.cache
+        entrypoint = self.server.find_entrypoint(url.scheme, hostname, path)
+
+        # Server name indication is required.
+        if not stream.get_servername():
+            report_error(stream, 59, "Missing TLS server name indication")
+            return
+        if stream.get_servername().decode() != url.hostname:
+            report_error(stream, 53, "Proxy request refused")
+            return
+
+        cache = None if (url.scheme != 'gemini' or identity or len(url.query) > 0) \
+            else self.server.cache
         is_from_cache = False
 
         # print(f'Request : {request}')
@@ -246,7 +188,7 @@ class Server:
         if entrypoint:
             # Check the cache first.
             if cache:
-                media, content = cache.try_load(path)
+                media, content = cache.try_load(hostname + path)
                 if not media is None:
                     response = 20, media, content
                     is_from_cache = True
@@ -291,6 +233,125 @@ class Server:
 
             # Save to cache.
             if not is_from_cache and cache and status == 20:
-                cache.save(path, meta, response_data)
+                cache.save(hostname + path, meta, response_data)
         else:
             report_error(stream, 50, 'Permanent failure')
+
+
+class Server:
+    def __init__(self, hostname_or_hostnames, cert_path, key_path,
+                 address='localhost', port=1965,
+                 cache=None, session_id=None, max_upload_size=0, num_threads=1,
+                 require_upload_identity=True):
+        self.hostnames = [hostname_or_hostnames] \
+            if type(hostname_or_hostnames) == str else hostname_or_hostnames
+        self.address = address
+        self.port = port
+        self.entrypoints = {'gemini': {}, 'titan': {}}
+        for proto in ['gemini', 'titan']:
+            self.entrypoints[proto] = {}
+            for hostname in self.hostnames:
+                self.entrypoints[proto][hostname] = []
+        self.cache = cache
+        self.max_upload_size = max_upload_size
+        self.require_upload_identity = require_upload_identity
+
+        self.context = SSL.Context(SSL.TLS_SERVER_METHOD)
+        self.context.use_certificate_file(str(cert_path))
+        self.context.use_privatekey_file(str(key_path))
+        self.context.set_verify(SSL.VERIFY_PEER, verify_callback)
+        if session_id:
+            if type(session_id) != bytes:
+                raise Exception("session_id type must be `bytes`")
+            self.context.set_session_id(session_id)
+
+        # Spawn the worker threads.
+        self.workers = []
+        self.work_queue = queue.Queue()
+        for worker_id in range(max(num_threads, 1)):
+            worker = Worker(worker_id, self)
+            self.workers.append(worker)
+
+        attempts = 60
+        print(f'Opening port {port}...')
+        while True:
+            try:
+                self.sock = socket.socket()
+                self.sock.bind((address, port))
+                self.sock.listen(5)
+                self.sv_conn = SSL.Connection(self.context, self.sock)
+                self.sv_conn.set_accept_state()
+                break
+            except:
+                attempts -= 1
+                if attempts == 0:
+                    raise Exception(f'Failed to open port {port} for listening')
+                time.sleep(2.0)
+                print('...')
+        print(f'Server started on port {port}')
+
+    def add_entrypoint(self, protocol, hostname, path_pattern, entrypoint):
+        self.entrypoints[protocol][hostname].append((path_pattern, entrypoint))
+
+    def __setitem__(self, key, value):
+        #if key.endswith('*'):
+        #    self.wild_entrypoints[key[:-1]] = value
+        #else:
+        #    self.entrypoints[key] = value
+        for hostname in self.hostnames:
+            self.add_entrypoint('gemini', hostname, key, value)
+
+    # def __getitem__(self, key):
+    #     if key.endswith('*'):
+    #         return self.wild_entrypoints[key[:-1]]
+    #     return self.entrypoints[key]
+
+    def run(self):
+        for worker in self.workers:
+            worker.start()
+        while True:
+            try:
+                stream = None
+                try:
+                    stream, from_addr = self.sv_conn.accept()
+                    #print(stream, from_addr)
+                    #self.process_request(stream, from_addr)
+                    self.work_queue.put((stream, from_addr))
+                except KeyboardInterrupt:
+                    print('\nStopping the server...')
+                    break
+                except Exception as ex:
+                    import traceback
+                    traceback.print_exc()
+                    print(ex)
+            except Exception as ex:
+                print(ex)
+        for i in range(len(self.workers)):
+            self.work_queue.put(None)
+        for worker in self.workers:
+            worker.join()
+
+    def find_entrypoint(self, protocol, hostname, path):
+        try:
+            for entry in self.entrypoints[protocol][hostname]:
+                if len(entry[0]) == 0 or fnmatch.fnmatch(path, entry[0]):
+                    return entry[1]
+        except:
+            return None
+
+        # # Check the more specific virtual host entrypoint first.
+        # virt_path = f":{hostname}:{path}"
+        # if virt_path in self.entrypoints:
+        #     return self.entrypoints[virt_path]
+        # for entry in self.wild_entrypoints:
+        #     if virt_path.startswith(entry):
+        #         return self.wild_entrypoints[entry]
+
+        # if path in self.entrypoints:
+        #     return self.entrypoints[path]
+        # for entry in self.wild_entrypoints:
+        #     if path.startswith(entry):
+        #         return self.wild_entrypoints[entry]
+
+        return None
+
diff --git a/gmcapsule.py b/gmcapsule.py
index 4886141..bc0c2a4 100644
--- a/gmcapsule.py
+++ b/gmcapsule.py
@@ -13,28 +13,43 @@ import gemini
 
 
 class Config:
-    def __init__(self):
-        # TODO: Get this using configparser
-        self.hostnames = ['localhost']
-        self.address = '0.0.0.0'
-        self.port = 1965
-        self.certs_dir = Path('.certs')
-        self.root_dir = Path('.')        # vhosts as subdirs
-        self.mod_dir = Path('modules')   # extension modules
-        self.max_upload_size = 10 * 1024 * 1024
-        self.cgi = {
-            'gemini': {
-                'localhost': [
-                    ('/test', ['/bin/ls', '-l'])
-                ]
-            },
-            'titan': {
-                'localhost': [
-                    ('/test', ['printenv']),
-                    ('/test/*', ['printenv'])
-                ]
-            }
-        }
+    def __init__(self, config_path):
+        self.ini = configparser.ConfigParser()
+        if os.path.exists(config_path):
+            self.ini.read(config_path)
+        else:
+            print(config_path, 'not found -- using defaults')
+
+    def hostnames(self):
+        return self.ini.get('server', 'host', fallback='localhost').split()
+
+    def address(self):
+        return self.ini.get('server', 'address', fallback='0.0.0.0')
+
+    def port(self):
+        return self.ini.getint('server', 'port', fallback=1965)
+
+    def certs_dir(self):
+        return Path(self.ini.get('server', 'certs', fallback='.certs'))
+
+    def root_dir(self):
+        return Path(self.ini.get('server', 'root', fallback='.'))
+
+    def mod_dir(self):
+        return Path(self.ini.get('server', 'modules', fallback='modules'))
+
+    def num_threads(self):
+        return self.ini.getint('server', 'threads', fallback=5)
+
+    def max_upload_size(self):
+        return self.ini.getint('titan', 'upload_size', fallback=10 * 1024 * 1024)
+
+    def prefixed_sections(self, prefix):
+        sects = {}
+        for name in self.ini.sections():
+            if not name.startswith(prefix): continue
+            sects[name[len(prefix):]] = self.ini[name]
+        return sects
 
 
 class Capsule:
@@ -44,13 +59,14 @@ class Capsule:
         Capsule._capsule = self
         self.cfg = cfg
         self.sv = gemini.Server(
-            cfg.hostnames,
-            cfg.certs_dir / 'cert.pem',
-            cfg.certs_dir / 'key.pem',
-            address=cfg.address,
-            port=cfg.port,
-            session_id=f'GmCapsule:{cfg.port}'.encode('utf-8'),
-            max_upload_size=cfg.max_upload_size
+            cfg.hostnames(),
+            cfg.certs_dir() / 'cert.pem',
+            cfg.certs_dir() / 'key.pem',
+            address=cfg.address(),
+            port=cfg.port(),
+            session_id=f'GmCapsule:{cfg.port()}'.encode('utf-8'),
+            max_upload_size=cfg.max_upload_size(),
+            num_threads=cfg.num_threads()
         )
         # Modules define the entrypoints.
         self.load_modules()
@@ -63,21 +79,21 @@ class Capsule:
         if hostname:
             self.sv.add_entrypoint(protocol, hostname, path, entrypoint)
         else:
-            for hostname in self.cfg.hostnames:
+            for hostname in self.cfg.hostnames():
                 if not hostname:
                     raise Exception(f'invalid hostname: "{hostname}"')
                 self.sv.add_entrypoint(protocol, hostname, path, entrypoint)
 
     def load_modules(self):
-        for mod_file in sorted(os.listdir(self.cfg.mod_dir)):
+        for mod_file in sorted(os.listdir(self.cfg.mod_dir())):
             if mod_file.endswith('.py'):
-                path = (self.cfg.mod_dir / mod_file).resolve()
+                path = (self.cfg.mod_dir() / mod_file).resolve()
                 name = mod_file[:-3]
-                print('Module:', name)
                 loader = importlib.machinery.SourceFileLoader(name, str(path))
                 spec = importlib.util.spec_from_loader(name, loader)
                 mod = importlib.util.module_from_spec(spec)
                 loader.exec_module(mod)
+                print('Module:', mod.__doc__)
                 mod.init(self)
 
     def run(self):
diff --git a/gmcapsuled b/gmcapsuled
index cab942e..212401d 100755
--- a/gmcapsuled
+++ b/gmcapsuled
@@ -5,17 +5,21 @@
 # License: BSD 2-Clause
 
 import argparse
+import os
 import sys
 import threading
 
 from gmcapsule import *
+from pathlib import Path
 
 VERSION = '0.1'
 
 print(f"GmCapsule {VERSION}")
 
-# TODO: Parse command arguments.
+argp = argparse.ArgumentParser(description='GmCapsule is an extensible server for Gemini and Titan.')
+argp.add_argument('-C', '--config', dest='config_file', default=Path.home() / '.gmcapsulerc')
+args = argp.parse_args()
 
-cfg = Config()
+cfg = Config(args.config_file)
 capsule = Capsule(cfg)
 capsule.run()
diff --git a/modules/400_cgi.py b/modules/40_cgi.py
similarity index 54%
rename from modules/400_cgi.py
rename to modules/40_cgi.py
index c635a85..70bd04a 100644
--- a/modules/400_cgi.py
+++ b/modules/40_cgi.py
@@ -1,5 +1,7 @@
+"""CGI commands"""
 
 import os
+import shlex
 import subprocess
 import urllib.parse
 
@@ -15,7 +17,6 @@ class CgiContext:
 
     def __call__(self, req):
         try:
-            cfg = Capsule.config()
             query = urllib.parse.unquote(req.query)
             env_vars = dict(os.environ)
 
@@ -24,21 +25,29 @@ class CgiContext:
             env_vars['QUERY_STRING'] = req.query
             assert req.path.startswith(self.base_path)
             env_vars['PATH_INFO'] = req.path[len(self.base_path):]
+            env_vars['SERVER_PROTOCOL'] = req.scheme
+            env_vars['SERVER_NAME'] = req.hostname
+            env_vars['SERVER_PORT'] = str(Capsule.config().port())
 
+            # TLS client certificate.
             if req.identity:
-                env_vars['REMOTE_IDENT'] = str(req.identity)
-                env_vars['REMOTE_USER'] = req.identity.subject()
+                env_vars['AUTH_TYPE'] = 'TLS'
+                env_vars['REMOTE_IDENT'] = str(req.identity)      # cert fingerprints
+                env_vars['REMOTE_USER'] = req.identity.subject()  # "/CN=name"
+            else:
+                env_vars['AUTH_TYPE'] = ''
 
+            # Titan metadata.
             if req.content:
-                env_vars['TITAN_TOKEN'] = req.content_token if req.content_token is not None else ''
-                env_vars['TITAN_MIME'] = req.content_mime if req.content_mime is not None else ''
+                env_vars['CONTENT_LENGTH'] = str(len(req.content))
+                env_vars['CONTENT_TYPE'] = req.content_mime if req.content_mime is not None else ''
+                env_vars['CONTENT_TOKEN'] = req.content_token if req.content_token is not None else ''
 
-            print(req.content)
-
-            result = subprocess.run(self.args, check=True,
-                input=req.content,
-                stdout=subprocess.PIPE,
-                env=env_vars).stdout
+            result = subprocess.run(self.args,
+                                    check=True,
+                                    input=req.content,
+                                    stdout=subprocess.PIPE,
+                                    env=env_vars).stdout
             try:
                 # Parse response header.
                 crlf_pos = result.find(b'\r\n')
@@ -62,7 +71,10 @@ class CgiContext:
 
 def init(capsule):
     cfg = Capsule.config()
-    for protocol in cfg.cgi:
-        for hostname in cfg.cgi[protocol]:
-            for entry in cfg.cgi[protocol][hostname]:
-                capsule.add(entry[0], CgiContext(entry[0], entry[1]), hostname, protocol)
+    default_host = cfg.hostnames()[0]
+    for section in Capsule.config().prefixed_sections('cgi.').values():
+        protocol = section.get('protocol', fallback='gemini')
+        host = section.get('host', fallback=default_host)
+        path = section.get('path', fallback='/*')
+        args = shlex.split(section.get('command'))
+        capsule.add(path, CgiContext(path, args), host, protocol)
diff --git a/modules/500_static.py b/modules/50_static.py
similarity index 91%
rename from modules/500_static.py
rename to modules/50_static.py
index a525993..7a9f95f 100644
--- a/modules/500_static.py
+++ b/modules/50_static.py
@@ -1,3 +1,5 @@
+"""Static files from the host content directory"""
+
 import fnmatch
 import os.path
 import string
@@ -10,7 +12,7 @@ META = '.meta'
 
 def check_meta_rules(path, hostname):
     cfg = Capsule.config()
-    root = (cfg.root_dir / hostname).resolve()
+    root = (cfg.root_dir() / hostname).resolve()
     dir = path.parent
     while True:
         if not str(dir.resolve()).startswith(str(root)):
@@ -44,7 +46,7 @@ def serve_file(req):
         if seg != '.' and seg != '..' and seg.startswith('.'):
             return 51, "Not found"
 
-    host_root = (cfg.root_dir / req.hostname).resolve()
+    host_root = (cfg.root_dir() / req.hostname).resolve()
     path = (host_root / req.path[1:]).resolve()
     if not str(path).startswith(str(host_root)):
         return 51, "Not found"
Proxy Information
Original URL
gemini://git.skyjake.fi/gmcapsule/gsorg-style/cdiff/528c1f08cf847bb1f708fca851e0a480d9fcf18c
Status Code
Success (20)
Meta
text/gemini; charset=utf-8
Capsule Response Time
31.993992 milliseconds
Gemini-to-HTML Time
1.2082 milliseconds

This content has been proxied by September (ba2dc).