Diffstat (limited to 'src_py')
-rw-r--r--  src_py/hatter/__init__.py      0
-rw-r--r--  src_py/hatter/__main__.py      8
-rw-r--r--  src_py/hatter/backend.py     171
-rw-r--r--  src_py/hatter/common.py       10
-rw-r--r--  src_py/hatter/executor.py    197
-rw-r--r--  src_py/hatter/main.py        189
-rw-r--r--  src_py/hatter/server.py      163
-rw-r--r--  src_py/hatter/util.py        147
8 files changed, 133 insertions, 752 deletions
diff --git a/src_py/hatter/__init__.py b/src_py/hatter/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src_py/hatter/__init__.py
diff --git a/src_py/hatter/__main__.py b/src_py/hatter/__main__.py
new file mode 100644
index 0000000..065fe88
--- /dev/null
+++ b/src_py/hatter/__main__.py
@@ -0,0 +1,8 @@
+import sys
+
+from hatter.main import main
+
+
+if __name__ == '__main__':
+    sys.argv[0] = 'hatter'
+    sys.exit(main())
diff --git a/src_py/hatter/backend.py b/src_py/hatter/backend.py
deleted file mode 100644
index 8b78570..0000000
--- a/src_py/hatter/backend.py
+++ /dev/null
@@ -1,171 +0,0 @@
-import sqlite3
-import datetime
-import threading
-import logging
-import concurrent.futures
-import asyncio
-
-from hatter import util
-from hatter import executor
-
-
-util.monkeypatch_sqlite3()
-
-
-LogEntry = util.namedtuple('LogEntry',
- ['timestamp', 'datetime.datetime: timestamp'],
- ['repository', 'str: repository'],
- ['commit', 'str: commit'],
- ['msg', 'str: message'])
-
-Job = util.namedtuple('Job',
- ['id', 'int: id'],
- ['timestamp', 'datetime.datetime: timestamp'],
- ['repository', 'str: repository'],
- ['commit', 'str: commit'])
-
-
-class Backend:
-
- def __init__(self, db_path, repositories):
- self._repositories = repositories
- self._next_job_id = 0
- self._active = None
- self._queue = []
- self._active_change_cbs = util.CallbackRegistry()
- self._queue_change_cbs = util.CallbackRegistry()
- self._log_change_cbs = util.CallbackRegistry()
- self._cv = asyncio.Condition()
- self._db = _DB(db_path)
- self._executor = concurrent.futures.ThreadPoolExecutor()
- self._run_loop_future = asyncio.ensure_future(self._run_loop())
-
- @property
- def repositories(self):
- return self._repositories
-
- @property
- def active(self):
- return self._active
-
- @property
- def queue(self):
- return self._queue
-
- def register_active_change_cb(self, cb):
- return self._active_change_cbs.register(cb)
-
- def register_queue_change_cb(self, cb):
- return self._queue_change_cbs.register(cb)
-
- def register_log_change_cb(self, cb):
- return self._log_change_cbs.register(cb)
-
- async def async_close(self):
- self._run_loop_future.cancel()
- await self._run_loop_future
-
- async def query_log(self, offset, limit):
- return await asyncio.get_event_loop().run_in_executor(
- self._executor, self._db.query, offset, limit)
-
- async def add_job(self, repository, commit):
- job = Job(id=self._next_job_id,
- timestamp=datetime.datetime.now(datetime.timezone.utc),
- repository=repository,
- commit=commit)
- self._next_job_id += 1
- with await self._cv:
- self._queue.append(job)
- self._cv.notify_all()
- self._queue_change_cbs.notify()
-
- async def _run_loop(self):
- log = logging.getLogger('hatter.project')
- while True:
- with await self._cv:
- while not self._queue:
- await self._cv.wait()
- self._active = self._queue.pop(0)
- self._queue_change_cbs.notify()
- self._active_change_cbs.notify()
-
- handler = _LogHandler(
- self._db, self._active.repository, self._active.commit)
- log.addHandler(handler)
- try:
- await asyncio.get_event_loop().run_in_executor(
- self._executor, executor.run,
- log, self._active.repository, self._active.commit)
- except asyncio.CancelledError:
- break
- except Exception as e:
- log.error("%s", e, exc_info=True)
- finally:
- log.removeHandler(handler)
- self._active = None
- self._active_change_cbs.notify()
-
-
-class _LogHandler(logging.Handler):
-
- def __init__(self, db, repository, commit):
- super().__init__()
- self._db = db
- self._repository = repository
- self._commit = commit
-
- def emit(self, record):
- self._db.add(
- timestamp=datetime.datetime.fromtimestamp(
- record.created, datetime.timezone.utc),
- repository=self._repository,
- commit=self._commit,
- msg=record.getMessage())
-
-
-class _DB:
-
- def __init__(self, db_path):
- db_path.parent.mkdir(exist_ok=True)
- self._db = sqlite3.connect('file:{}?nolock=1'.format(db_path),
- uri=True,
- isolation_level=None,
- detect_types=sqlite3.PARSE_DECLTYPES)
- self._db.executescript("CREATE TABLE IF NOT EXISTS log ("
- "timestamp TIMESTAMP, "
- "repository TEXT, "
- "commit TEXT, "
- "msg TEXT)")
- self._db.commit()
- self._lock = threading.Lock()
-
- def close(self):
- with self._lock:
- self._db.close()
-
- def add(self, timestamp, repository, commit, msg):
- with self._lock:
- self._db.execute(
- "INSERT INTO log VALUES "
- "(:timestamp, :repository, :commit, :msg)",
- {'timestamp': timestamp,
- 'repository': repository,
- 'commit': commit,
- 'msg': msg})
-
- def query(self, offset, limit):
- with self._lock:
- c = self._db.execute(
- "SELECT rowid, * FROM log ORDER BY rowid DESC "
- "LIMIT :limit OFFSET :offset",
- {'limit': limit, 'offset': offset})
- try:
- result = c.fetchall()
- except Exception as e:
- result = []
- return [LogEntry(timestamp=i[1],
- repository=i[2],
- commit=i[3],
- msg=i[4])
- for i in result]
diff --git a/src_py/hatter/common.py b/src_py/hatter/common.py
new file mode 100644
index 0000000..cf99bb9
--- /dev/null
+++ b/src_py/hatter/common.py
@@ -0,0 +1,10 @@
+from pathlib import Path
+
+from hat import json
+
+
+package_path: Path = Path(__file__).parent
+
+json_schema_repo: json.SchemaRepository = json.SchemaRepository(
+    json.json_schema_repo,
+    json.SchemaRepository.from_json(package_path / 'json_schema_repo.json'))
diff --git a/src_py/hatter/executor.py b/src_py/hatter/executor.py
deleted file mode 100644
index 2d1ae2a..0000000
--- a/src_py/hatter/executor.py
+++ /dev/null
@@ -1,197 +0,0 @@
-import tempfile
-import pathlib
-import tarfile
-import subprocess
-import io
-import time
-import contextlib
-import yaml
-import libvirt
-import paramiko
-
-import hatter.json_validator
-
-
-def run(log, repo_path, commit='HEAD', archive_name='hatter_archive'):
- log.info('starting executor for repository {} ({})'.format(repo_path,
- commit))
- t_begin = time.monotonic()
- archive_file_name = archive_name + '.tar.gz'
- with tempfile.TemporaryDirectory() as tempdir:
- archive_path = pathlib.Path(tempdir) / archive_file_name
- log.info('fetching remote repository')
- _git_archive(repo_path, commit, archive_path)
- log.info('loading project configuration')
- conf = _load_conf(archive_path)
- for i in conf:
- log.info('starting virtual machine')
- with contextlib.closing(_VM(i['vm'])) as vm:
- log.info('creating SSH connection')
- with contextlib.closing(_SSH(i['ssh'], vm.address)) as ssh:
- log.info('transferring repository to virtual machine')
- ssh.execute('rm -rf {} {}'.format(archive_file_name,
- archive_name))
- ssh.upload(archive_path, archive_file_name)
- ssh.execute('mkdir {}'.format(archive_name))
- ssh.execute('tar xf {} -C {}'.format(archive_file_name,
- archive_name))
- log.info('executing scripts')
- for script in i['scripts']:
- ssh.execute(script, archive_name, log)
- t_end = time.monotonic()
- log.info('executor finished (duration: {}s)'.format(t_end - t_begin))
-
-
-class _VM:
-
- def __init__(self, conf):
- self._conn = None
- self._domain = None
- self._temp_snapshot = None
- self._address = None
- try:
- self._conn = _libvirt_connect(conf.get('uri', 'qemu:///system'))
- self._domain = _libvirt_get_domain(self._conn, conf['domain'])
- if self._domain.isActive():
- self._domain.destroy()
- self._temp_snapshot = _libvirt_create_temp_snapshot(
- self._domain, conf.get('temp_snapshot', 'temp_hatter'))
- if 'snapshot' in conf:
- _libvirt_revert_snapshot(self._domain, conf['snapshot'])
- _libvirt_start_domain(self._domain)
- for _ in range(conf.get('get_address_retry_count', 10)):
- self._address = _libvirt_get_address(self._domain)
- if self._address:
- return
- time.sleep(conf.get('get_address_delay', 5))
- raise Exception('ip address not detected')
- except Exception:
- self.close()
- raise
-
- @property
- def address(self):
- return self._address
-
- def close(self):
- if self._domain:
- self._domain.destroy()
- if self._domain and self._temp_snapshot:
- self._domain.revertToSnapshot(self._temp_snapshot)
- if self._temp_snapshot:
- self._temp_snapshot.delete()
- if self._conn:
- self._conn.close()
- self._temp_snapshot = None
- self._domain = None
- self._conn = None
- self._address = None
-
-
-class _SSH:
-
- def __init__(self, conf, address):
- self._conn = paramiko.SSHClient()
- self._conn.set_missing_host_key_policy(paramiko.AutoAddPolicy)
- for _ in range(conf.get('connect_retry_count', 10)):
- try:
- self._conn.connect(
- address,
- username=conf['username'], password=conf['password'],
- timeout=conf.get('connect_timeout', 1),
- auth_timeout=conf.get('connect_timeout', 1))
- return
- except Exception as e:
- time.sleep(conf.get('connect_delay', 5))
- raise Exception('could not connect to {}'.format(address))
-
- def close(self):
- if self._conn:
- self._conn.close()
- self._conn = None
-
- def upload(self, src_path, dst_path):
- with contextlib.closing(self._conn.open_sftp()) as sftp:
- sftp.put(str(src_path), str(dst_path))
-
- def execute(self, cmd, cwd='.', log=None):
- if log:
- log.info('executing command: {}'.format(cmd))
- with contextlib.closing(self._conn.invoke_shell()) as shell:
- shell.set_combine_stderr(True)
- shell.exec_command('cd {} && {}'.format(cwd, cmd))
- with contextlib.closing(shell.makefile()) as f:
- data = f.read()
- if log:
- log.info('command output: {}'.format(data))
- exit_code = shell.recv_exit_status()
- if exit_code > 0:
- raise Exception('command exit code is {}'.format(exit_code))
-
-
-def _load_conf(archive_path):
- with tarfile.open(archive_path) as archive:
- f = io.TextIOWrapper(archive.extractfile('.hatter.yml'),
- encoding='utf-8')
- conf = yaml.safe_load(f)
- hatter.json_validator.validate(conf, 'hatter://project.yaml#')
- return conf
-
-
-def _git_archive(repo_path, commit, output_path):
- result = subprocess.run(
- ['git', 'archive', '--format=tar.gz',
- '--output={}'.format(str(output_path)),
- '--remote={}'.format(repo_path),
- commit],
- stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
- if result.returncode:
- raise Exception("could not archive {} from {}".format(commit,
- repo_path))
-
-
-def _libvirt_connect(uri):
- conn = libvirt.open(uri)
- if not conn:
- raise Exception('could not open connection to {}'.format(uri))
- return conn
-
-
-def _libvirt_get_domain(conn, domain_name):
- domain = conn.lookupByName(domain_name)
- if not domain:
- raise Exception('domain {} not available'.format(domain_name))
- return domain
-
-
-def _libvirt_start_domain(domain):
- if domain.create():
- raise Exception('could not run vm')
-
-
-def _libvirt_create_temp_snapshot(domain, temp_snapshot_name):
- temp_snapshot = domain.snapshotLookupByName(temp_snapshot_name)
- if temp_snapshot:
- temp_snapshot.delete()
- temp_snapshot = domain.snapshotCreateXML(
- "<domainsnapshot><name>{}</name></domainsnapshot>".format(
- temp_snapshot_name))
- if not temp_snapshot:
- raise Exception('could not create snapshot {}'.format(
- temp_snapshot_name))
- return temp_snapshot
-
-
-def _libvirt_revert_snapshot(domain, snapshot_name):
- snapshot = domain.snapshotLookupByName(snapshot_name)
- if not snapshot:
- raise Exception('snapshot {} not available'.format(snapshot_name))
- if domain.revertToSnapshot(snapshot):
- raise Exception('could not revert snapshot {}'.format(snapshot_name))
-
-
-def _libvirt_get_address(domain):
- addresses = domain.interfaceAddresses(0)
- for i in addresses.values():
- for j in i.get('addrs', []):
- return j.get('addr')
diff --git a/src_py/hatter/main.py b/src_py/hatter/main.py
index 71245a2..f1046a5 100644
--- a/src_py/hatter/main.py
+++ b/src_py/hatter/main.py
@@ -1,74 +1,115 @@
-import sys
-import asyncio
-import argparse
-import pdb
-import yaml
-import logging.config
-import atexit
-import pkg_resources
-import pathlib
-
-import hatter.json_validator
-from hatter import util
-from hatter.backend import Backend
-from hatter.server import create_web_server
-
-
-def main():
-    args = _create_parser().parse_args()
-
-    with open(args.conf, encoding='utf-8') as conf_file:
-        conf = yaml.safe_load(conf_file)
-    hatter.json_validator.validate(conf, 'hatter://server.yaml#')
-
-    if 'log' in conf:
-        logging.config.dictConfig(conf['log'])
-
-    if args.web_path:
-        web_path = args.web_path
-    else:
-        atexit.register(pkg_resources.cleanup_resources)
-        web_path = pkg_resources.resource_filename('hatter', 'web')
-
-    util.run_until_complete_without_interrupt(async_main(conf, web_path))
-
-
-async def async_main(conf, web_path):
-    backend = None
-    web_server = None
-    try:
-        backend = Backend(pathlib.Path(conf.get('db_path', 'hatter.db')),
-                          conf['repositories'])
-        web_server = await create_web_server(
-            backend, conf.get('host', '0.0.0.0'), conf.get('port', 24000),
-            conf.get('webhook_path', '/webhook'), web_path)
-        await asyncio.Future()
-    except asyncio.CancelledError:
-        pass
-    except Exception as e:
-        pdb.set_trace()
-        raise
-    finally:
-        if web_server:
-            await web_server.async_close()
-        if backend:
-            await backend.async_close()
-        await asyncio.sleep(0.5)
-
-
-def _create_parser():
-    parser = argparse.ArgumentParser(prog='hatter')
-    parser.add_argument(
-        '--web-path', default=None, metavar='path', dest='web_path',
-        help="web ui directory path")
-
-    named_arguments = parser.add_argument_group('required named arguments')
-    named_arguments.add_argument(
-        '-c', '--conf', required=True, metavar='path', dest='conf',
-        help='configuration path')
-
-    return parser
-
-
-if __name__ == '__main__':
-    sys.exit(main())
+from pathlib import Path
+import asyncio
+import contextlib
+import logging.config
+import sys
+import tempfile
+import typing
+import subprocess
+
+import appdirs
+import click
+
+from hat import aio
+from hat import json
+
+from hatter import common
+
+
+user_config_dir: Path = Path(appdirs.user_config_dir('hatter'))
+user_data_dir: Path = Path(appdirs.user_data_dir('hatter'))
+
+default_conf_path: Path = user_config_dir / 'server.yaml'
+default_db_path: Path = user_data_dir / 'hatter.db'
+
+ssh_key_path: typing.Optional[Path] = None
+
+
+@click.group()
+@click.option('--log-level',
+              default='INFO',
+              type=click.Choice(['CRITICAL', 'ERROR', 'WARNING', 'INFO',
+                                 'DEBUG', 'NOTSET']),
+              help="log level")
+@click.option('--ssh-key', default=None, metavar='PATH', type=Path,
+              help="private key used for ssh authentication")
+def main(log_level: str,
+         ssh_key: typing.Optional[Path]):
+    global ssh_key_path
+    ssh_key_path = ssh_key
+
+    logging.config.dictConfig({
+        'version': 1,
+        'formatters': {
+            'console': {
+                'format': "[%(asctime)s %(levelname)s %(name)s] %(message)s"}},
+        'handlers': {
+            'console': {
+                'class': 'logging.StreamHandler',
+                'formatter': 'console',
+                'level': log_level}},
+        'root': {
+            'level': log_level,
+            'handlers': ['console']},
+        'disable_existing_loggers': False})
+
+
+@main.command()
+@click.argument('url', required=True)
+@click.argument('branch', required=False, default='master')
+@click.argument('action', required=False, default='.hatter.yaml')
+def execute(url: str,
+            branch: str,
+            action: str):
+    with tempfile.TemporaryDirectory() as repo_dir:
+        repo_dir = Path(repo_dir)
+
+        subprocess.run(['git', 'clone', '-q', '--depth', '1',
+                        '-b', branch, url, str(repo_dir)],
+                       check=True)
+
+        conf = json.decode_file(repo_dir / action)
+        common.json_schema_repo.validate('hatter://action.yaml#', conf)
+
+        image = conf['image']
+        command = conf['command']
+        subprocess.run(['podman', 'run', '-i', '--rm',
+                        '-v', f'{repo_dir}:/hatter',
+                        image, '/bin/sh'],
+                       input=f'set -e\ncd /hatter\n{command}\n',
+                       encoding='utf-8',
+                       check=True)
+
+
+@main.command()
+@click.option('--host', default='0.0.0.0',
+              help="listening host name (default 0.0.0.0)")
+@click.option('--port', default=24000, type=int,
+              help="listening TCP port (default 24000)")
+@click.option('--conf', default=default_conf_path, metavar='PATH', type=Path,
+              help="configuration defined by hatter://server.yaml# "
+                   "(default $XDG_CONFIG_HOME/hatter/server.yaml)")
+@click.option('--db', default=default_db_path, metavar='PATH', type=Path,
+              help="sqlite database path "
+                   "(default $XDG_DATA_HOME/hatter/hatter.db)")
+def server(host: str,
+           port: int,
+           conf: Path,
+           db: Path):
+    conf = json.decode_file(conf)
+    common.json_schema_repo.validate('hatter://server.yaml#', conf)
+
+    with contextlib.suppress(asyncio.CancelledError):
+        aio.run_asyncio(async_server(host, port, conf, db))
+
+
+async def async_server(host: str,
+                       port: int,
+                       conf: json.Data,
+                       db_path: Path):
+    pass
+
+
+if __name__ == '__main__':
+    sys.argv[0] = 'hatter'
+    main()
diff --git a/src_py/hatter/server.py b/src_py/hatter/server.py
deleted file mode 100644
index baeffac..0000000
--- a/src_py/hatter/server.py
+++ /dev/null
@@ -1,163 +0,0 @@
-import asyncio
-import json
-import aiohttp.web
-
-from hatter import util
-import hatter.json_validator
-
-
-async def create_web_server(backend, host, port, webhook_path, web_path):
- srv = WebServer()
- srv._backend = backend
- srv._app = aiohttp.web.Application()
- srv._app.router.add_route(
- 'GET', '/', lambda req: aiohttp.web.HTTPFound('/index.html'))
- srv._app.router.add_route('*', '/ws', srv._ws_handler)
- srv._app.router.add_route('POST', webhook_path, srv._webhook_handler)
- srv._app.router.add_static('/', web_path)
- srv._app_handler = srv._app.make_handler()
- srv._srv = await asyncio.get_event_loop().create_server(
- srv._app_handler, host=host, port=port)
- return srv
-
-
-class WebServer:
-
- async def async_close(self):
- self._srv.close()
- await self._srv.wait_closed()
- await self._app.shutdown()
- await self._app_handler.finish_connections(0)
- await self._app.cleanup()
-
- async def _ws_handler(self, request):
- ws = aiohttp.web.WebSocketResponse()
- await ws.prepare(request)
- client = _Client(self._backend, ws)
- await client.run()
- return ws
-
- async def _webhook_handler(self, request):
- try:
- if not ({'X-Gitlab-Event', 'X-GitHub-Event'} &
- set(request.headers.keys())):
- raise Exception('unsupported webhook request')
- body = await request.read()
- data = json.loads(body)
- req = _parse_webhook_request(request.headers, data)
- for commit in req.commits:
- await self._backend.add_job(req.url, commit)
- except Exception:
- pass
- return aiohttp.web.Response()
-
-
-_WebhookRequest = util.namedtuple('_WebhookRequest', 'url', 'commits')
-
-
-def _parse_webhook_request(headers, data):
- if headers.get('X-Gitlab-Event') == 'Push Hook':
- url = data['repository']['git_http_url']
- commits = [commit['id'] for commit in data['commits']]
- elif headers.get('X-GitHub-Event') == 'push':
- url = data['repository']['clone_url']
- commits = [commit['id'] for commit in data['commits']]
- else:
- raise Exception('unsupported webhook event')
- return _WebhookRequest(url, commits)
-
-
-class _Client:
-
- def __init__(self, backend, ws):
- self._backend = backend
- self._ws = ws
- self._log_offset = 0
- self._log_limit = 0
- self._log_entries = []
- self._active_job = None
- self._job_queue = []
-
- async def run(self):
- self._active_job = self._backend.active
- self._job_queue = list(self._backend.queue)
- with self._backend.register_active_change_cb(self._on_active_change):
- with self._backend.register_queue_change_cb(self._on_queue_change):
- with self._backend.register_log_change_cb(self._on_log_change):
- try:
- self._send_repositories()
- self._send_active_job()
- self._send_job_queue()
- self._send_log_entries()
- while True:
- msg = await self._ws.receive()
- if self._ws.closed:
- break
- if msg.type != aiohttp.WSMsgType.TEXT:
- continue
- json_msg = json.loads(msg.data, encoding='utf-8')
- hatter.json_validator.validate(json_msg, 'hatter://message.yaml#/definitions/client_message') # NOQA
- await self._process_msg(json_msg)
- except Exception as e:
- print('>>>', e)
-
- async def _process_msg(self, msg):
- if msg['type'] == 'set_log':
- self._log_offset = msg['log_offset']
- self._log_limit = msg['log_limit']
- await self._update_log()
- elif msg['type'] == 'add_job':
- await self._backend.add_job(msg['repository'], msg['commit'])
-
- def _on_active_change(self):
- if self._active_job != self._backend.active:
- self._active_job = self._backend.active
- self._send_active_job()
-
- def _on_queue_change(self):
- if self._job_queue != self._backend.queue:
- self._job_queue = list(self._backend.queue)
- self._send_job_queue()
-
- def _on_log_change(self):
- asyncio.ensure_future(self._update_log())
-
- async def _update_log(self):
- log_entries = await self._backend.query_log(self._log_offset, self._log_limit)
- if log_entries != self._log_entries:
- self._log_entries = log_entries
- self._send_log_entries()
-
- def _send_repositories(self):
- self._ws.send_str(json.dumps({
- 'type': 'repositories',
- 'repositories': self._backend.repositories}))
-
- def _send_active_job(self):
- self._ws.send_str(json.dumps({
- 'type': 'active_job',
- 'job': _job_to_json(self._active_job)}))
-
- def _send_job_queue(self):
- self._ws.send_str(json.dumps({
- 'type': 'job_queue',
- 'jobs': [_job_to_json(i) for i in self._job_queue]}))
-
- def _send_log_entries(self):
- self._ws.send_str(json.dumps({
- 'type': 'log_entries',
- 'entries': [_log_entry_to_json(i) for i in self._log_entries]}))
-
-
-def _job_to_json(job):
- return {'id': job.id,
- 'timestamp': job.timestamp.timestamp(),
- 'repository': job.repository,
- 'commit': job.commit}
-
-
-def _log_entry_to_json(entry):
- return {'timestamp': entry.timestamp.timestamp(),
- 'repository': entry.repository,
- 'commit': entry.commit,
- 'message': entry.msg}
diff --git a/src_py/hatter/util.py b/src_py/hatter/util.py
deleted file mode 100644
index a4fff55..0000000
--- a/src_py/hatter/util.py
+++ /dev/null
@@ -1,147 +0,0 @@
-import collections
-import sys
-import contextlib
-import asyncio
-import datetime
-import sqlite3
-
-
-def namedtuple(name, *props):
- """Create documented namedtuple
-
- Args:
- name (Union[str,Tuple[str,str]]):
- named tuple's name or named tuple's name with documentation
- props (Sequence[Union[str,Tuple[str,str]]]):
- named tuple's properties with optional documentation
-
- Returns:
- class implementing collections.namedtuple
-
- """
- props = [(i, None) if isinstance(i, str) else i for i in props]
- cls = collections.namedtuple(name if isinstance(name, str) else name[0],
- [i[0] for i in props])
- if not isinstance(name, str) and name[1]:
- cls.__doc__ = name[1]
- for k, v in props:
- if v:
- getattr(cls, k).__doc__ = v
- try:
- cls.__module__ = sys._getframe(1).f_globals.get('__name__', '__main__')
- except (AttributeError, ValueError):
- pass
- return cls
-
-
-def run_until_complete_without_interrupt(future):
- """Run event loop until future or coroutine is done
-
- Args:
- future (Awaitable): future or coroutine
-
- Returns:
- Any: provided future's result
-
- KeyboardInterrupt is suppressed (while the event loop is running) and is
- mapped to a single cancellation of the running task. If multiple
- KeyboardInterrupts occur, the task is cancelled only once.
-
- """
- async def ping_loop():
- with contextlib.suppress(asyncio.CancelledError):
- while True:
- await asyncio.sleep(1)
-
- task = asyncio.ensure_future(future)
- if sys.platform == 'win32':
- ping_loop_task = asyncio.ensure_future(ping_loop())
- with contextlib.suppress(KeyboardInterrupt):
- asyncio.get_event_loop().run_until_complete(task)
- asyncio.get_event_loop().call_soon(task.cancel)
- if sys.platform == 'win32':
- asyncio.get_event_loop().call_soon(ping_loop_task.cancel)
- while not task.done():
- with contextlib.suppress(KeyboardInterrupt):
- asyncio.get_event_loop().run_until_complete(task)
- if sys.platform == 'win32':
- while not ping_loop_task.done():
- with contextlib.suppress(KeyboardInterrupt):
- asyncio.get_event_loop().run_until_complete(ping_loop_task)
- return task.result()
-
-
-def monkeypatch_sqlite3():
- """Monkeypatch sqlite timestamp converter"""
-
- def _sqlite_convert_timestamp(val):
- datepart, timetzpart = val.split(b" ")
- if b"+" in timetzpart:
- tzsign = 1
- timepart, tzpart = timetzpart.split(b"+")
- elif b"-" in timetzpart:
- tzsign = -1
- timepart, tzpart = timetzpart.split(b"-")
- else:
- timepart, tzpart = timetzpart, None
- year, month, day = map(int, datepart.split(b"-"))
- timepart_full = timepart.split(b".")
- hours, minutes, seconds = map(int, timepart_full[0].split(b":"))
- if len(timepart_full) == 2:
- microseconds = int('{:0<6.6}'.format(timepart_full[1].decode()))
- else:
- microseconds = 0
- if tzpart:
- tzhours, tzminutes = map(int, tzpart.split(b":"))
- tz = datetime.timezone(
- tzsign * datetime.timedelta(hours=tzhours, minutes=tzminutes))
- else:
- tz = None
-
- val = datetime.datetime(year, month, day, hours, minutes, seconds,
- microseconds, tz)
- return val
-
- sqlite3.register_converter("timestamp", _sqlite_convert_timestamp)
-
-
-class RegisterCallbackHandle(collections.namedtuple(
- 'RegisterCallbackHandle', ['cancel'])):
- """Handle used for canceling callback registration
-
- Attributes:
- cancel (Callable[[],None]): cancel registered callback
-
- """
-
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
- self.cancel()
-
-
-class CallbackRegistry:
- """Callback registry"""
-
- def __init__(self):
- self._cbs = []
-
- def register(self, cb):
- """Register callback
-
- Args:
- cb (Callable): callback
-
- Returns:
- RegisterCallbackHandle
-
- """
- self._cbs.append(cb)
- return RegisterCallbackHandle(lambda: self._cbs.remove(cb))
-
- def notify(self, *args, **kwargs):
- """Notify all registered callbacks"""
-
- for cb in self._cbs:
- cb(*args, **kwargs)