diff options
Diffstat (limited to 'src_py')
| -rw-r--r-- | src_py/hatter/executor.py | 185 | ||||
| -rw-r--r-- | src_py/hatter/vm.py | 105 |
2 files changed, 185 insertions, 105 deletions
diff --git a/src_py/hatter/executor.py b/src_py/hatter/executor.py new file mode 100644 index 0000000..86c8e56 --- /dev/null +++ b/src_py/hatter/executor.py @@ -0,0 +1,185 @@ +import tempfile +import pathlib +import tarfile +import subprocess +import io +import time +import contextlib +import yaml +import libvirt +import paramiko + +import hatter.json_validator + + +def run(repo_path, commit='HEAD', archive_name='hatter_archive'): + archive_file_name = archive_name + '.tar.gz' + with tempfile.TemporaryDirectory() as tempdir: + archive_path = pathlib.Path(tempdir) / archive_file_name + _git_archive(repo_path, commit, archive_path) + conf = _load_conf(archive_path) + with contextlib.closing(_VM(conf['vm'])) as vm: + with contextlib.closing(_SSH(conf['ssh'], vm.address)) as ssh: + ssh.upload(archive_path, archive_file_name) + ssh.execute('rm -rf {} {}'.format(archive_file_name, + archive_name)) + ssh.upload(archive_path, archive_file_name) + ssh.execute('mkdir {}'.format(archive_name)) + ssh.execute('tar xf {} -C {}'.format(archive_file_name, + archive_name)) + for script in conf['scripts']: + ssh.execute(script, archive_name) + + +class _VM: + + def __init__(self, conf): + self._conn = None + self._domain = None + self._temp_snapshot = None + self._address = None + try: + self._conn = _libvirt_connect(conf.get('uri', 'qemu:///system')) + self._domain = _libvirt_get_domain(self._conn, conf['domain']) + if self._domain.isActive(): + self._domain.destroy() + self._temp_snapshot = _libvirt_create_temp_snapshot( + self._domain, conf.get('temp_snapshot', 'temp_hatter')) + if 'snapshot' in conf: + _libvirt_revert_snapshot(self._domain, conf['snapshot']) + _libvirt_start_domain(self._domain) + for _ in range(conf('get_address_retry_count', 10)): + self._address = _libvirt_get_address(self._domain) + if self._address: + return + time.sleep(conf.get('get_address_delay', 5)) + raise Exception('ip addess not detected') + except Exception: + self.close() + raise + + @property + def address(self): + return self._address + + def close(self): + if self._domain: + self._domain.destroy() + if self._domain and self._temp_snapshot: + self._domain.revertToSnapshot(self._temp_snapshot) + if self._temp_snapshot: + self._temp_snapshot.delete() + if self._conn: + self._conn.close() + self._temp_snapshot = None + self._domain = None + self._conn = None + self._address = None + + +class _SSH: + + def __init__(self, conf, address): + self._conn = paramiko.SSHClient() + self._conn.set_missing_host_key_policy(paramiko.AutoAddPolicy) + for _ in range(conf.get('connect_retry_count', 10)): + try: + self._conn.connect( + address, + username=conf['username'], password=conf['password'], + timeout=conf.get('connect_timeout', 1), + auth_timeout=conf.get('connect_timeout', 1)) + return + except Exception as e: + time.sleep(conf.get('connect_delay', 5)) + raise Exception('could not connect to {}'.format(address)) + + def close(self): + if self._conn: + self._conn.close() + self._conn = None + + def upload(self, src_path, dst_path): + with contextlib.closing(self._conn.open_sftp()) as sftp: + sftp.put(str(src_path), str(dst_path)) + + def execute(self, cmd, cwd='.'): + with contextlib.closing(self._conn.invoke_shell()) as shell: + shell.set_combine_stderr(True) + shell.exec_command('cd {} && {}'.format(cwd, cmd)) + with contextlib.closing(shell.makefile()) as f: + data = f.read() + # TODO remove print + print('>> ', cmd) + print(data) + exit_code = shell.recv_exit_status() + if exit_code > 0: + raise Exception('command exit code is {}'.format(exit_code)) + + +def _load_conf(archive_path): + with tarfile.open(archive_path) as archive: + f = io.TextIOWrapper(archive.extractfile('.hatter.yml'), + encoding='utf-8') + conf = yaml.safe_load(f) + hatter.json_validator.validate(conf, 'hatter://project.yaml#') + return conf + + +def _git_archive(repo_path, commit, output_path): + result = subprocess.run( + ['git', 'archive', '--format=tar.gz', + '--outfile={}'.format(str(output_path)), + '--remote={}'.format(repo_path), + commit], + stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + if result.returncode: + raise Exception("could not archive {} from {}".format(commit, + repo_path)) + + +def _libvirt_connect(uri): + conn = libvirt.open(uri) + if not conn: + raise Exception('could not open connection to {}'.format(uri)) + return conn + + +def _libvirt_get_domain(conn, domain_name): + domain = conn.lookupByName(domain_name) + if not domain: + raise Exception('domain {} not available'.format(domain_name)) + return domain + + +def _libvirt_start_domain(domain): + if domain.create(): + raise Exception('could not run vm') + + +def _libvirt_create_temp_snapshot(domain, temp_snapshot_name): + temp_snapshot = domain.snapshotLookupByName(temp_snapshot_name) + if temp_snapshot: + temp_snapshot.delete() + temp_snapshot = domain.snapshotCreateXML( + "<domainsnapshot><name>{}</name></domainsnapshot>".format( + temp_snapshot_name)) + if not temp_snapshot: + raise Exception('could not create snapshot {}'.format( + temp_snapshot_name)) + return temp_snapshot + + +def _libvirt_revert_snapshot(domain, snapshot_name): + snapshot = domain.snapshotLookupByName(snapshot_name) + if not snapshot: + raise Exception('snapshot {} not available'.format(snapshot_name)) + if domain.revertToSnapshot(snapshot): + raise Exception('could not revert snapshot {}'.format(snapshot_name)) + + +def _libvirt_get_address(domain): + addresses = domain.interfaceAddresses(0) + for i in addresses.values(): + for j in i.get('addrs', []): + return j.get('addr') diff --git a/src_py/hatter/vm.py b/src_py/hatter/vm.py deleted file mode 100644 index 87e71b9..0000000 --- a/src_py/hatter/vm.py +++ /dev/null @@ -1,105 +0,0 @@ -import time -import libvirt - - -class VM: - - def __init__(self, domain_name, uri='qemu:///system', snapshot_name=None, - temp_snapshot_name='temp_hatter', address_retry_count=10, - address_delay=5): - self._domain_name = domain_name - self._uri = uri - self._snapshot_name = snapshot_name - self._temp_snapshot_name = temp_snapshot_name - self._address_retry_count = address_retry_count - self._address_delay = address_delay - self._conn = None - self._domain = None - self._temp_snapshot = None - self._address = None - - @property - def address(self): - return self._address - - def start(self): - try: - self._conn = _libvirt_connect(self._uri) - self._domain = _libvirt_get_domain(self._conn, self._domain_name) - if self._domain.isActive(): - self._domain.destroy() - self._temp_snapshot = _libvirt_create_temp_snapshot( - self._domain, self._temp_snapshot_name) - if self._snapshot_name: - _libvirt_revert_snapshot(self._domain, self._snapshot_name) - _libvirt_start_domain(self._domain) - for _ in range(self._address_retry_count): - self._address = _libvirt_get_address(self._domain) - if self._address: - return - time.sleep(self._address_delay) - raise Exception('ip addess not detected') - except Exception: - self.stop() - raise - - def stop(self): - if self._domain: - self._domain.destroy() - if self._domain and self._temp_snapshot: - self._domain.revertToSnapshot(self._temp_snapshot) - if self._temp_snapshot: - self._temp_snapshot.delete() - if self._conn: - self._conn.close() - self._temp_snapshot = None - self._domain = None - self._conn = None - self._address = None - - -def _libvirt_connect(uri): - conn = libvirt.open(uri) - if not conn: - raise Exception('could not open connection to {}'.format(uri)) - return conn - - -def _libvirt_get_domain(conn, domain_name): - domain = conn.lookupByName(domain_name) - if not domain: - raise Exception('domain {} not available'.format(domain_name)) - return domain - - -def _libvirt_start_domain(domain): - if domain.create(): - raise Exception('could not run vm') - - -def _libvirt_create_temp_snapshot(domain, temp_snapshot_name): - temp_snapshot = domain.snapshotLookupByName(temp_snapshot_name) - if temp_snapshot: - temp_snapshot.delete() - temp_snapshot = domain.snapshotCreateXML( - "<domainsnapshot><name>{}</name></domainsnapshot>".format( - temp_snapshot_name)) - if not temp_snapshot: - raise Exception('could not create snapshot {}'.format( - temp_snapshot_name)) - return temp_snapshot - - -def _libvirt_revert_snapshot(domain, snapshot_name): - snapshot = domain.snapshotLookupByName(snapshot_name) - if not snapshot: - raise Exception('snapshot {} not available'.format(snapshot_name)) - if domain.revertToSnapshot(snapshot): - raise Exception('could not revert snapshot {}'.format(snapshot_name)) - - -def _libvirt_get_address(domain): - addresses = domain.interfaceAddresses(0) - for i in addresses.values(): - for j in i.get('addrs', []): - return j.get('addr') |
