diff options
| author | bozo.kopic <bozo@kopic.xyz> | 2022-03-22 01:31:27 +0100 |
|---|---|---|
| committer | bozo.kopic <bozo@kopic.xyz> | 2022-03-22 01:31:27 +0100 |
| commit | cc4ba3b063f14943579ffbfe416828590f70ae0a (patch) | |
| tree | af2127920fb57603206ca670beb63b5d58650fb8 /src_py | |
| parent | c594b1fca854a7b9fb73d854a9830143cd1032fc (diff) | |
WIP major rewrite
Diffstat (limited to 'src_py')
| -rw-r--r-- | src_py/hatter/__init__.py | 0 | ||||
| -rw-r--r-- | src_py/hatter/__main__.py | 8 | ||||
| -rw-r--r-- | src_py/hatter/backend.py | 171 | ||||
| -rw-r--r-- | src_py/hatter/common.py | 10 | ||||
| -rw-r--r-- | src_py/hatter/executor.py | 197 | ||||
| -rw-r--r-- | src_py/hatter/main.py | 189 | ||||
| -rw-r--r-- | src_py/hatter/server.py | 163 | ||||
| -rw-r--r-- | src_py/hatter/util.py | 147 |
8 files changed, 133 insertions, 752 deletions
diff --git a/src_py/hatter/__init__.py b/src_py/hatter/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src_py/hatter/__init__.py diff --git a/src_py/hatter/__main__.py b/src_py/hatter/__main__.py new file mode 100644 index 0000000..065fe88 --- /dev/null +++ b/src_py/hatter/__main__.py @@ -0,0 +1,8 @@ +import sys + +from hatter.main import main + + +if __name__ == '__main__': + sys.argv[0] = 'hatter' + sys.exit(main()) diff --git a/src_py/hatter/backend.py b/src_py/hatter/backend.py deleted file mode 100644 index 8b78570..0000000 --- a/src_py/hatter/backend.py +++ /dev/null @@ -1,171 +0,0 @@ -import sqlite3 -import datetime -import threading -import logging -import concurrent.futures -import asyncio - -from hatter import util -from hatter import executor - - -util.monkeypatch_sqlite3() - - -LogEntry = util.namedtuple('LogEntry', - ['timestamp', 'datetime.datetime: timestamp'], - ['repository', 'str: repository'], - ['commit', 'str: commit'], - ['msg', 'str: message']) - -Job = util.namedtuple('Job', - ['id', 'int: id'], - ['timestamp', 'datetime.datetime: timestamp'], - ['repository', 'str: repository'], - ['commit', 'str: commit']) - - -class Backend: - - def __init__(self, db_path, repositories): - self._repositories = repositories - self._next_job_id = 0 - self._active = None - self._queue = [] - self._active_change_cbs = util.CallbackRegistry() - self._queue_change_cbs = util.CallbackRegistry() - self._log_change_cbs = util.CallbackRegistry() - self._cv = asyncio.Condition() - self._db = _DB(db_path) - self._executor = concurrent.futures.ThreadPoolExecutor() - self._run_loop_future = asyncio.ensure_future(self._run_loop()) - - @property - def repositories(self): - return self._repositories - - @property - def active(self): - return self._active - - @property - def queue(self): - return self._queue - - def register_active_change_cb(self, cb): - return self._active_change_cbs.register(cb) - - def register_queue_change_cb(self, cb): - return self._queue_change_cbs.register(cb) - - def register_log_change_cb(self, cb): - return self._log_change_cbs.register(cb) - - async def async_close(self): - self._run_loop_future.cancel() - await self._run_loop_future - - async def query_log(self, offset, limit): - return await asyncio.get_event_loop().run_in_executor( - self._executor, self._db.query, offset, limit) - - async def add_job(self, repository, commit): - job = Job(id=self._next_job_id, - timestamp=datetime.datetime.now(datetime.timezone.utc), - repository=repository, - commit=commit) - self._next_job_id += 1 - with await self._cv: - self._queue.append(job) - self._cv.notify_all() - self._queue_change_cbs.notify() - - async def _run_loop(self): - log = logging.getLogger('hatter.project') - while True: - with await self._cv: - while not self._queue: - await self._cv.wait() - self._active = self._queue_change_cbs.pop(0) - self._queue_change_cbs.notify() - self._active_change_cbs.notify() - - handler = _LogHandler( - self._db, self._active.repository, self._active.commit) - log.addHandler(handler) - try: - await asyncio.get_event_loop().run_in_executor( - self._executor, executor.run, - log, self._active.repository, self._active.commit) - except asyncio.CancelledError: - break - except Exception as e: - log.error("%s", e, exc_info=True) - finally: - log.removeHandler(handler) - self._active = None - self._active_change_cbs.notify() - - -class _LogHandler(logging.Handler): - - def __init__(self, db, repository, commit): - super().__init__() - self._db = db - self._repository - self._commit = commit - - def emit(self, record): - self._db.add( - timestamp=datetime.datetime.fromtimestamp( - record.created, datetime.timezone.utc), - repository=self._repository, - commit=self._commit, - msg=record.getMessage()) - - -class _DB: - - def __init__(self, db_path): - db_path.parent.mkdir(exist_ok=True) - self._db = sqlite3.connect('file:{}?nolock=1'.format(db_path), - uri=True, - isolation_level=None, - detect_types=sqlite3.PARSE_DECLTYPES) - self._db.executescript("CREATE TABLE IF NOT EXISTS log (" - "timestamp TIMESTAMP, " - "repository TEXT, " - "commit TEXT, " - "msg TEXT)") - self._db.commit() - self._lock = threading.Lock() - - def close(self): - with self._lock: - self._db.close() - - def add(self, timestamp, repository, commit, msg): - with self._lock: - self._db.execute( - "INSERT INTO log VALUES " - "(:timestamp, :repository, :commit, :msg)", - {'timestamp': timestamp, - 'repository': repository, - 'commit': commit, - 'msg': msg}) - - def query(self, offset, limit): - with self._lock: - c = self._db.execute( - "SELECT rowid, * FROM log ORDER BY rowid DESC " - "LIMIT :limit OFFSET :offset", - {'limit': limit, 'offset': offset}) - try: - result = c.fetchall() - except Exception as e: - result = [] - return [LogEntry(timestamp=i[1], - repository=i[2], - commit=i[3], - msg=i[4]) - for i in result] diff --git a/src_py/hatter/common.py b/src_py/hatter/common.py new file mode 100644 index 0000000..cf99bb9 --- /dev/null +++ b/src_py/hatter/common.py @@ -0,0 +1,10 @@ +from pathlib import Path + +from hat import json + + +package_path: Path = Path(__file__).parent + +json_schema_repo: json.SchemaRepository = json.SchemaRepository( + json.json_schema_repo, + json.SchemaRepository.from_json(package_path / 'json_schema_repo.json')) diff --git a/src_py/hatter/executor.py b/src_py/hatter/executor.py deleted file mode 100644 index 2d1ae2a..0000000 --- a/src_py/hatter/executor.py +++ /dev/null @@ -1,197 +0,0 @@ -import tempfile -import pathlib -import tarfile -import subprocess -import io -import time -import contextlib -import yaml -import libvirt -import paramiko - -import hatter.json_validator - - -def run(log, repo_path, commit='HEAD', archive_name='hatter_archive'): - log.info('starting executor for repository {} ({})'.format(repo_path, - commit)) - t_begin = time.monotonic() - archive_file_name = archive_name + '.tar.gz' - with tempfile.TemporaryDirectory() as tempdir: - archive_path = pathlib.Path(tempdir) / archive_file_name - log.info('fetching remote repository') - _git_archive(repo_path, commit, archive_path) - log.info('loading project configuration') - conf = _load_conf(archive_path) - for i in conf: - log.info('starting virtual machine') - with contextlib.closing(_VM(i['vm'])) as vm: - log.info('creating SSH connection') - with contextlib.closing(_SSH(i['ssh'], vm.address)) as ssh: - log.info('transfering repository to virtual machine') - ssh.execute('rm -rf {} {}'.format(archive_file_name, - archive_name)) - ssh.upload(archive_path, archive_file_name) - ssh.execute('mkdir {}'.format(archive_name)) - ssh.execute('tar xf {} -C {}'.format(archive_file_name, - archive_name)) - log.info('executing scripts') - for script in i['scripts']: - ssh.execute(script, archive_name, log) - t_end = time.monotonic() - log.info('executor finished (duration: {}s)'.format(t_end - t_begin)) - - -class _VM: - - def __init__(self, conf): - self._conn = None - self._domain = None - self._temp_snapshot = None - self._address = None - try: - self._conn = _libvirt_connect(conf.get('uri', 'qemu:///system')) - self._domain = _libvirt_get_domain(self._conn, conf['domain']) - if self._domain.isActive(): - self._domain.destroy() - self._temp_snapshot = _libvirt_create_temp_snapshot( - self._domain, conf.get('temp_snapshot', 'temp_hatter')) - if 'snapshot' in conf: - _libvirt_revert_snapshot(self._domain, conf['snapshot']) - _libvirt_start_domain(self._domain) - for _ in range(conf('get_address_retry_count', 10)): - self._address = _libvirt_get_address(self._domain) - if self._address: - return - time.sleep(conf.get('get_address_delay', 5)) - raise Exception('ip addess not detected') - except Exception: - self.close() - raise - - @property - def address(self): - return self._address - - def close(self): - if self._domain: - self._domain.destroy() - if self._domain and self._temp_snapshot: - self._domain.revertToSnapshot(self._temp_snapshot) - if self._temp_snapshot: - self._temp_snapshot.delete() - if self._conn: - self._conn.close() - self._temp_snapshot = None - self._domain = None - self._conn = None - self._address = None - - -class _SSH: - - def __init__(self, conf, address): - self._conn = paramiko.SSHClient() - self._conn.set_missing_host_key_policy(paramiko.AutoAddPolicy) - for _ in range(conf.get('connect_retry_count', 10)): - try: - self._conn.connect( - address, - username=conf['username'], password=conf['password'], - timeout=conf.get('connect_timeout', 1), - auth_timeout=conf.get('connect_timeout', 1)) - return - except Exception as e: - time.sleep(conf.get('connect_delay', 5)) - raise Exception('could not connect to {}'.format(address)) - - def close(self): - if self._conn: - self._conn.close() - self._conn = None - - def upload(self, src_path, dst_path): - with contextlib.closing(self._conn.open_sftp()) as sftp: - sftp.put(str(src_path), str(dst_path)) - - def execute(self, cmd, cwd='.', log=None): - if log: - log.info('executing command: {}'.format(cmd)) - with contextlib.closing(self._conn.invoke_shell()) as shell: - shell.set_combine_stderr(True) - shell.exec_command('cd {} && {}'.format(cwd, cmd)) - with contextlib.closing(shell.makefile()) as f: - data = f.read() - if log: - log.info('command output: {}'.format(data)) - exit_code = shell.recv_exit_status() - if exit_code > 0: - raise Exception('command exit code is {}'.format(exit_code)) - - -def _load_conf(archive_path): - with tarfile.open(archive_path) as archive: - f = io.TextIOWrapper(archive.extractfile('.hatter.yml'), - encoding='utf-8') - conf = yaml.safe_load(f) - hatter.json_validator.validate(conf, 'hatter://project.yaml#') - return conf - - -def _git_archive(repo_path, commit, output_path): - result = subprocess.run( - ['git', 'archive', '--format=tar.gz', - '--outfile={}'.format(str(output_path)), - '--remote={}'.format(repo_path), - commit], - stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - if result.returncode: - raise Exception("could not archive {} from {}".format(commit, - repo_path)) - - -def _libvirt_connect(uri): - conn = libvirt.open(uri) - if not conn: - raise Exception('could not open connection to {}'.format(uri)) - return conn - - -def _libvirt_get_domain(conn, domain_name): - domain = conn.lookupByName(domain_name) - if not domain: - raise Exception('domain {} not available'.format(domain_name)) - return domain - - -def _libvirt_start_domain(domain): - if domain.create(): - raise Exception('could not run vm') - - -def _libvirt_create_temp_snapshot(domain, temp_snapshot_name): - temp_snapshot = domain.snapshotLookupByName(temp_snapshot_name) - if temp_snapshot: - temp_snapshot.delete() - temp_snapshot = domain.snapshotCreateXML( - "<domainsnapshot><name>{}</name></domainsnapshot>".format( - temp_snapshot_name)) - if not temp_snapshot: - raise Exception('could not create snapshot {}'.format( - temp_snapshot_name)) - return temp_snapshot - - -def _libvirt_revert_snapshot(domain, snapshot_name): - snapshot = domain.snapshotLookupByName(snapshot_name) - if not snapshot: - raise Exception('snapshot {} not available'.format(snapshot_name)) - if domain.revertToSnapshot(snapshot): - raise Exception('could not revert snapshot {}'.format(snapshot_name)) - - -def _libvirt_get_address(domain): - addresses = domain.interfaceAddresses(0) - for i in addresses.values(): - for j in i.get('addrs', []): - return j.get('addr') diff --git a/src_py/hatter/main.py b/src_py/hatter/main.py index 71245a2..f1046a5 100644 --- a/src_py/hatter/main.py +++ b/src_py/hatter/main.py @@ -1,74 +1,115 @@ -import sys
-import asyncio
-import argparse
-import pdb
-import yaml
-import logging.config
-import atexit
-import pkg_resources
-import pathlib
-
-import hatter.json_validator
-from hatter import util
-from hatter.backend import Backend
-from hatter.server import create_web_server
-
-
-def main():
- args = _create_parser().parse_args()
-
- with open(args.conf, encoding='utf-8') as conf_file:
- conf = yaml.safe_load(conf_file)
- hatter.json_validator.validate(conf, 'hatter://server.yaml#')
-
- if 'log' in conf:
- logging.config.dictConfig(conf['log'])
-
- if args.web_path:
- web_path = args.web_path
- else:
- atexit.register(pkg_resources.cleanup_resources)
- web_path = pkg_resources.resource_filename('hatter', 'web')
-
- util.run_until_complete_without_interrupt(async_main(conf, web_path))
-
-
-async def async_main(conf, web_path):
- backend = None
- web_server = None
- try:
- backend = Backend(pathlib.Path(conf.get('db_path', 'hatter.db')),
- conf['repositories'])
- web_server = await create_web_server(
- backend, conf.get('host', '0.0.0.0'), conf.get('port', 24000),
- conf.get('webhook_path', '/webhook'), web_path)
- await asyncio.Future()
- except asyncio.CancelledError:
- pass
- except Exception as e:
- pdb.set_trace()
- raise
- finally:
- if web_server:
- await web_server.async_close()
- if backend:
- await backend.async_close()
- await asyncio.sleep(0.5)
-
-
-def _create_parser():
- parser = argparse.ArgumentParser(prog='hatter')
- parser.add_argument(
- '--web-path', default=None, metavar='path', dest='web_path',
- help="web ui directory path")
-
- named_arguments = parser.add_argument_group('required named arguments')
- named_arguments.add_argument(
- '-c', '--conf', required=True, metavar='path', dest='conf',
- help='configuration path')
-
- return parser
-
-
-if __name__ == '__main__':
- sys.exit(main())
+from pathlib import Path +import asyncio +import contextlib +import logging.config +import sys +import tempfile +import typing +import subprocess + +import appdirs +import click + +from hat import aio +from hat import json + +from hatter import common + + +user_config_dir: Path = Path(appdirs.user_config_dir('hatter')) +user_data_dir: Path = Path(appdirs.user_data_dir('hatter')) + +default_conf_path: Path = user_config_dir / 'server.yaml' +default_db_path: Path = user_data_dir / 'hatter.db' + +ssh_key_path: typing.Optional[Path] = None + + +@click.group() +@click.option('--log-level', + default='INFO', + type=click.Choice(['CRITICAL', 'ERROR', 'WARNING', 'INFO', + 'DEBUG', 'NOTSET']), + help="log level") +@click.option('--ssh-key', default=None, metavar='PATH', type=Path, + help="private key used for ssh authentication") +def main(log_level: str, + ssh_key: typing.Optional[Path]): + global ssh_key_path + ssh_key_path = ssh_key + + logging.config.dictConfig({ + 'version': 1, + 'formatters': { + 'console': { + 'format': "[%(asctime)s %(levelname)s %(name)s] %(message)s"}}, + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + 'formatter': 'console', + 'level': log_level}}, + 'root': { + 'level': log_level, + 'handlers': ['console']}, + 'disable_existing_loggers': False}) + + +@main.command() +@click.argument('url', required=True) +@click.argument('branch', required=False, default='master') +@click.argument('action', required=False, default='.hatter.yaml') +def execute(url: str, + branch: str, + action: str): + with tempfile.TemporaryDirectory() as repo_dir: + repo_dir = Path(repo_dir) + + subprocess.run(['git', 'clone', '-q', '--depth', '1', + '-b', branch, url, str(repo_dir)], + check=True) + + conf = json.decode_file(repo_dir / '.hatter.yaml') + common.json_schema_repo.validate('hatter://action.yaml#', conf) + + image = conf['image'] + command = conf['command'] + subprocess.run(['podman', 'run', '-i', '--rm', + '-v', f'{repo_dir}:/hatter', + image, '/bin/sh'], + input=f'set -e\ncd /hatter\n{command}\n', + encoding='utf-8', + check=True) + + +@main.command() +@click.option('--host', default='0.0.0.0', + help="listening host name (default 0.0.0.0)") +@click.option('--port', default=24000, type=int, + help="listening TCP port (default 24000)") +@click.option('--conf', default=default_conf_path, metavar='PATH', type=Path, + help="configuration defined by hatter://server.yaml# " + "(default $XDG_CONFIG_HOME/hatter/server.yaml)") +@click.option('--db', default=default_db_path, metavar='PATH', type=Path, + help="sqlite database path " + "(default $XDG_CONFIG_HOME/hatter/hatter.db") +def server(host: str, + port: int, + conf: Path, + db: Path): + conf = json.decode_file(conf) + common.json_schema_repo.validate('hatter://server.yaml#', conf) + + with contextlib.suppress(asyncio.CancelledError): + aio.run_asyncio(async_server(host, port, conf, db)) + + +async def async_server(host: str, + port: int, + conf: json.Data, + db_path: Path): + pass + + +if __name__ == '__main__': + sys.argv[0] = 'hatter' + main() diff --git a/src_py/hatter/server.py b/src_py/hatter/server.py deleted file mode 100644 index baeffac..0000000 --- a/src_py/hatter/server.py +++ /dev/null @@ -1,163 +0,0 @@ -import asyncio -import json -import aiohttp.web - -from hatter import util -import hatter.json_validator - - -async def create_web_server(backend, host, port, webhook_path, web_path): - srv = WebServer() - srv._backend = backend - srv._app = aiohttp.web.Application() - srv._app.router.add_route( - 'GET', '/', lambda req: aiohttp.web.HTTPFound('/index.html')) - srv._app.router.add_route('*', '/ws', srv._ws_handler) - srv._app.router.add_route('POST', webhook_path, srv._webhook_handler) - srv._app.router.add_static('/', web_path) - srv._app_handler = srv._app.make_handler() - srv._srv = await asyncio.get_event_loop().create_server( - srv._app_handler, host=host, port=port) - return srv - - -class WebServer: - - async def async_close(self): - self._srv.close() - await self._srv.wait_closed() - await self._app.shutdown() - await self._app_handler.finish_connections(0) - await self._app.cleanup() - - async def _ws_handler(self, request): - ws = aiohttp.web.WebSocketResponse() - await ws.prepare(request) - client = _Client(self._backend, ws) - await client.run() - return ws - - async def _webhook_handler(self, request): - try: - if not ({'X-Gitlab-Event', 'X-GitHub-Event'} & - set(request.headers.keys())): - raise Exception('unsupported webhook request') - body = await request.read() - data = json.loads(body) - req = _parse_webhook_request(request.headers, data) - for commit in req.commits: - self._backend.add_job(req.url, commit) - except Exception: - pass - return aiohttp.web.Response() - - -_WebhookRequest = util.namedtuple('_WebhookRequest', 'url', 'commits') - - -def _parse_webhook_request(headers, data): - if headers.get('X-Gitlab-Event') == 'Push Hook': - url = data['repository']['git_http_url'] - commits = [commit['id'] for commit in data['commits']] - elif headers.get('X-GitHub-Event') == 'push': - url = data['repository']['clone_url'] - commits = [commit['id'] for commit in data['commits']] - else: - raise Exception('unsupported webhook event') - return _WebhookRequest(url, commits) - - -class _Client: - - def __init__(self, backend, ws): - self._backend = backend - self._ws = ws - self._log_offset = 0 - self._log_limit = 0 - self._log_entries = [] - self._active_job = None - self._job_queue = [] - - async def run(self): - self._active_job = self._backend.active - self._job_queue = list(self._backend.queue) - with self._backend.register_active_change_cb(self._on_active_change): - with self._backend.register_queue_change_cb(self._on_queue_change): - with self._backend.register_log_change_cb(self._on_log_change): - try: - self._send_repositories() - self._send_active_job() - self._send_job_queue() - self._send_log_entries() - while True: - msg = await self._ws.receive() - if self._ws.closed: - break - if msg.type != aiohttp.WSMsgType.TEXT: - continue - json_msg = json.loads(msg.data, encoding='utf-8') - hatter.json_validator.validate(json_msg, 'hatter://message.yaml#/definitions/client_message') # NOQA - await self._process_msg(json_msg) - except Exception as e: - print('>>>', e) - - async def _process_msg(self, msg): - if msg['type'] == 'set_log': - self._log_offset = msg['log_offset'] - self._log_limit = msg['log_limit'] - await self._update_log() - elif msg['type'] == 'add_job': - await self._backend.add_job(msg['repository'], msg['commit']) - - def _on_active_change(self): - if self._active_job != self._backend.active: - self._active_job = self._backend.active - self._send_active_job() - - def _on_queue_change(self): - if self._job_queue != self._backend.job_queue: - self._job_queue = list(self._backend.job_queue) - self._send_job_queue() - - def _on_log_change(self): - asyncio.ensure_future(self._update_log()) - - async def _update_log(self, offset, limit): - log_entries = await self._backend.query_log(offset, limit) - if log_entries != self._log_entries: - self._log_entries = log_entries - self._send_log_entries() - - def _send_repositories(self): - self._ws.send_str(json.dumps({ - 'type': 'repositories', - 'repositories': self._backend.repositories})) - - def _send_active_job(self): - self._ws.send_str(json.dumps({ - 'type': 'active_job', - 'job': _job_to_json(self._active_job)})) - - def _send_job_queue(self): - self._ws.send_str(json.dumps({ - 'type': 'job_queue', - 'jobs': [_job_to_json(i) for i in self._job_queue]})) - - def _send_log_entries(self): - self._ws.send_str(json.dumps({ - 'type': 'log_entries', - 'entries': [_log_entry_to_json(i) for i in self._log_entries]})) - - -def _job_to_json(job): - return {'id': job.id, - 'timestamp': job.timestamp.timestamp(), - 'repository': job.repository, - 'commit': job.commit} - - -def _log_entry_to_json(entry): - return {'timestamp': entry.timestamp.timestamp(), - 'repository': entry.repository, - 'commit': entry.commit, - 'message': entry.msg} diff --git a/src_py/hatter/util.py b/src_py/hatter/util.py deleted file mode 100644 index a4fff55..0000000 --- a/src_py/hatter/util.py +++ /dev/null @@ -1,147 +0,0 @@ -import collections -import sys -import contextlib -import asyncio -import datetime -import sqlite3 - - -def namedtuple(name, *props): - """Create documented namedtuple - - Args: - name (Union[str,Tuple[str,str]]): - named tuple's name or named tuple's name with documentation - props (Sequence[Union[str,Tuple[str,str]]]): - named tuple' properties with optional documentation - - Returns: - class implementing collections.namedtuple - - """ - props = [(i, None) if isinstance(i, str) else i for i in props] - cls = collections.namedtuple(name if isinstance(name, str) else name[0], - [i[0] for i in props]) - if not isinstance(name, str) and name[1]: - cls.__doc__ = name[1] - for k, v in props: - if v: - getattr(cls, k).__doc__ = v - try: - cls.__module__ = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - pass - return cls - - -def run_until_complete_without_interrupt(future): - """Run event loop until future or coroutine is done - - Args: - future (Awaitable): future or coroutine - - Returns: - Any: provided future's result - - KeyboardInterrupt is suppressed (while event loop is running) and is mapped - to single cancelation of running task. If multipple KeyboardInterrupts - occur, task is cancelled only once. - - """ - async def ping_loop(): - with contextlib.suppress(asyncio.CancelledError): - while True: - await asyncio.sleep(1) - - task = asyncio.ensure_future(future) - if sys.platform == 'win32': - ping_loop_task = asyncio.ensure_future(ping_loop()) - with contextlib.suppress(KeyboardInterrupt): - asyncio.get_event_loop().run_until_complete(task) - asyncio.get_event_loop().call_soon(task.cancel) - if sys.platform == 'win32': - asyncio.get_event_loop().call_soon(ping_loop_task.cancel) - while not task.done(): - with contextlib.suppress(KeyboardInterrupt): - asyncio.get_event_loop().run_until_complete(task) - if sys.platform == 'win32': - while not ping_loop_task.done(): - with contextlib.suppress(KeyboardInterrupt): - asyncio.get_event_loop().run_until_complete(ping_loop_task) - return task.result() - - -def monkeypatch_sqlite3(): - """Monkeypatch sqlite timestamp converter""" - - def _sqlite_convert_timestamp(val): - datepart, timetzpart = val.split(b" ") - if b"+" in timetzpart: - tzsign = 1 - timepart, tzpart = timetzpart.split(b"+") - elif b"-" in timetzpart: - tzsign = -1 - timepart, tzpart = timetzpart.split(b"-") - else: - timepart, tzpart = timetzpart, None - year, month, day = map(int, datepart.split(b"-")) - timepart_full = timepart.split(b".") - hours, minutes, seconds = map(int, timepart_full[0].split(b":")) - if len(timepart_full) == 2: - microseconds = int('{:0<6.6}'.format(timepart_full[1].decode())) - else: - microseconds = 0 - if tzpart: - tzhours, tzminutes = map(int, tzpart.split(b":")) - tz = datetime.timezone( - tzsign * datetime.timedelta(hours=tzhours, minutes=tzminutes)) - else: - tz = None - - val = datetime.datetime(year, month, day, hours, minutes, seconds, - microseconds, tz) - return val - - sqlite3.register_converter("timestamp", _sqlite_convert_timestamp) - - -class RegisterCallbackHandle(collections.namedtuple( - 'RegisterCallbackHandle', ['cancel'])): - """Handle used for canceling callback registration - - Attributes: - cancel (Callable[[],None]): cancel registered callback - - """ - - def __enter__(self): - return self - - def __exit__(self, *args): - self.cancel() - - -class CallbackRegistry: - """Callback registry""" - - def __init__(self): - self._cbs = [] - - def register(self, cb): - """Register callback - - Args: - cb (Callable): callback - - Returns: - RegisterCallbackHandle - - """ - self.cbs.append(cb) - return RegisterCallbackHandle(lambda: self.cbs.remove(cb)) - - def notify(self, *args, **kwargs): - """Notify all registered callbacks""" - - for cb in self._cbs: - cb(*args, **kwargs) |
