Loading tpm2-algtest-collect +205 −8 Original line number Diff line number Diff line #!env python3 #!/bin/env python3.10 import asyncio import asyncssh import bson import click import logging import os import random import re import sys from datetime import datetime from pathlib import Path from typing import Any, Optional, Type, TypeVar from typing import Any, Optional, Tuple, Type, TypeVar T = TypeVar('T') SCRIPT_BIN = os.path.dirname(os.path.realpath(__file__)) REPO_PATH = '/var/lib/tpm2-algtest-nspawn/' UNIX_SOCK = os.path.join(SCRIPT_BIN, "_sock") DEPOSITORY = '/mnt/local/tpm/' class Utils: Loading Loading @@ -180,6 +185,38 @@ class CompositeReporterFactory(ReporterFactory): *list(map(lambda f: f.get_for(machine), self.factories))) class AlgtestStatus: def __init__(self): self._data = {} self._cond = asyncio.Condition() self._n = 0 async def _notify(self): async with self._cond: self._n += 1 self._cond.notify_all() async def add(self, machine: str) -> None: self._data[machine] = 'started' await self._notify() async def set(self, machine: str, status: str) -> None: self._data[machine] = status await self._notify() def machines(self) -> list[str]: return sorted(self._data.keys()) def data(self) -> dict[str, str]: return self._data async def data_wait(self, n: int = -1) -> Tuple[dict[str, str], int]: async with self._cond: while n == self._n: await self._cond.wait() return (self._data, self._n) class SSHReportErrSession(asyncssh.SSHClientSession[str]): def __init__(self, reporter: ReporterInterface): self.reporter = reporter Loading @@ -194,21 +231,27 @@ class Collector: @staticmethod async def _run_commands_with_connection( reporter: ReporterInterface, status: AlgtestStatus, machine: str, conn: asyncssh.SSHClientConnection, commands: list[str]) -> ReturnCode: for command in commands: for (index, command) in enumerate(commands): reporter.command(command) await status.set(machine, f"{index + 1}/{len(commands)}") chan, session = await conn.create_session( lambda: SSHReportErrSession(reporter), command) await chan.wait_closed() rc = ReturnCode(chan.get_returncode()) if not rc: reporter.failure(command, rc) await status.set(machine, f"failed ({index})") return rc await status.set(machine, "success") return ReturnCode.success() @staticmethod Loading @@ -221,27 +264,32 @@ class Collector: @staticmethod async def _run_commands(reporter: ReporterInterface, status: AlgtestStatus, machine: str, commands: list[str]) -> ReturnCode: try: await status.add(machine) reporter.status("Connecting") async with asyncssh.connect( machine, config=None, options=Collector._ssh_options() ) as conn: await status.set(machine, f"0/{len(commands)}") reporter.status("Connected") return await Collector._run_commands_with_connection( reporter, machine, conn, commands reporter, status, machine, conn, commands ) except Exception as ex: await status.set(machine, 'failed') reporter.ex(ex) return ReturnCode(255) @staticmethod def test() -> list[str]: return [ "echo Hello >&2; exit 1" "echo Hello >&2", "sleep $((5 + RANDOM % 10))", ] @staticmethod Loading @@ -249,6 +297,7 @@ class Collector: return [ f"test -d {REPO_PATH} || git clone {config['tpm2_tools_remote']} {REPO_PATH}", f"git -C {REPO_PATH} clean -xf && git -C {REPO_PATH} restore .", f"git -C {REPO_PATH} checkout main", f"git -C {REPO_PATH} pull --rebase", ] Loading @@ -270,6 +319,18 @@ class Collector: f"ps aux|fgrep nspawn >&2", ] @staticmethod async def _run_tasks(reporter_factory: ReporterFactory, machines: list[str], status: AlgtestStatus, server, commands: list[str]): await asyncio.gather( *[Collector._run_commands(reporter_factory.get_for(machine), status, machine, commands) for machine in machines] ) await server.stop() @staticmethod async def run(config: Config, machines: list[str], commands: list[str]) -> None: reporter_factory: Optional[ReporterFactory] = None Loading @@ -282,17 +343,75 @@ class Collector: else: reporter_factory = ReporterFactory(LogReporter) status = AlgtestStatus() server = UnixServer(UNIX_SOCK, status) await server.start() await asyncio.gather( *[Collector._run_commands(reporter_factory.get_for(machine), machine, commands) server.serve(), Collector._run_tasks(reporter_factory, machines, status, server, commands) ) @staticmethod async def _download_assets(machine: str) -> None: try: async with asyncssh.connect( machine, config=None, options=Collector._ssh_options()) as conn: path = os.path.join(REPO_PATH, '*/out-*.zip') await asyncssh.scp((conn, path), DEPOSITORY) except Exception as ex: print(f"{machine}: {ex}", file=sys.stderr) @staticmethod async def download(config: Config, machines: list[str]) -> None: await asyncio.gather( *[Collector._download_assets(machine) for machine in machines] ) class UnixServer: _status: Optional[AlgtestStatus] = None def __init__(self, path: str, status: AlgtestStatus): self.path = path UnixServer._status = status async def start(self): self.sock = await asyncio.start_unix_server(UnixServer._serve, path=self.path) logging.info(f"Started UNIX socket server: {self.path}") async def stop(self): self.sock.close() async def serve(self): try: await self.sock.serve_forever() except asyncio.exceptions.CancelledError: logging.info("Stopped UNIX socket server") @staticmethod async def _serve(rd: asyncio.StreamReader, wr: asyncio.StreamWriter) -> None: assert UnixServer._status is not None n = -1 try: while p := await UnixServer._status.data_wait(n): (data, n) = p wr.write(bson.dumps(data)) await wr.drain() except Exception as ex: print(ex) wr.close() def set_machines(machines: list[str]) -> list[str]: if not machines or (len(machines) == 1 and machines[0] == 'all'): return MACHINES if (len(machines) == 1 and machines[0] in ['nymfe', 'musa']): return list(filter(lambda host: host.startswith(machines[0]), MACHINES)) return list( map(lambda host: f"{host}" if "." in host else f"{host}.fi.muni.cz", machines) Loading @@ -309,10 +428,23 @@ def setup_logging() -> None: @click.group() @click.option('--debug', '--no-debug', default=False, is_flag=True) @click.option('--socket') @click.option('--repo-path') @click.pass_context def cli(ctx: dict[str, Any], debug: bool) -> None: def cli(ctx: dict[str, Any], debug: bool, socket: Optional[str], repo_path: Optional[str]) -> None: setup_logging() if socket is not None: if '/' not in socket: socket = os.path.join(SCRIPT_BIN, socket) global UNIX_SOCK UNIX_SOCK = socket if repo_path is not None: global REPO_PATH REPO_PATH = repo_path if debug: logging.getLogger().setLevel(logging.DEBUG) asyncssh.set_log_level(logging.DEBUG) Loading @@ -329,6 +461,14 @@ def test(ctx: click.Context, machines: list[str]) -> None: asyncio.run(Collector.run(ctx.obj['config'], machines, Collector.test())) @cli.command() @click.argument('machines', nargs=-1) @click.pass_context def download(ctx: click.Context, machines: list[str]) -> None: machines = set_machines(machines) asyncio.run(Collector.download(ctx.obj['config'], machines)) @cli.command() @click.argument('machines', nargs=-1) @click.pass_context Loading @@ -342,8 +482,14 @@ def update(ctx: click.Context, machines: list[str]) -> None: @click.pass_context def clean(ctx: click.Context, machines: list[str]) -> None: machines = set_machines(machines) # TODO: Investigating nymfe55 if "nymfe55.fi.muni.cz" in machines: machines.remove("nymfe55.fi.muni.cz") asyncio.run(Collector.run(ctx.obj['config'], machines, Collector.clean())) @cli.command() @click.argument('machines', nargs=-1) @click.pass_context Loading @@ -351,6 +497,7 @@ def generate(ctx: click.Context, machines: list[str]) -> None: machines = set_machines(machines) asyncio.run(Collector.run(ctx.obj['config'], machines, Collector.generate())) @cli.command() @click.argument('machines', nargs=-1) @click.pass_context Loading @@ -359,6 +506,56 @@ def kill(ctx: click.Context, machines: list[str]) -> None: asyncio.run(Collector.run(ctx.obj['config'], machines, Collector.kill())) async def read_status() -> None: (rd, wr) = await asyncio.open_unix_connection(UNIX_SOCK) while True: bs = await rd.read(8192) if len(bs) == 0: break raw = bson.loads(bs) data = {} done = 0 for (key, value) in raw.items(): data[(key.split('.'))[0]] = value if value in ['success', 'failed']: done += 1 max_key = max(map(lambda s: len(s), data.keys())) (width, _) = os.get_terminal_size() cols = max(1, width // (max_key + 12)) def chunks(lst, size): for i in range(0, len(lst), size): yield lst[i:i + size] click.clear() print(f"⌚{datetime.now().strftime('%F %T')} Status: {done} / {len(data)}") for chunk in chunks(list(data.keys()), cols): for machine in chunk: status = data[machine] name = f"{machine:{max_key}}" if status == 'success': print(f" \x1b[32m✔ {name}\x1b[0m ", end="") done += 1 elif status == 'failed': print(f" \x1b[91m✘ {name}\x1b[0m ", end="") done += 1 elif m := re.search(r'^(\d)/(\d)$', status): print(f" \x1b[93m*\x1b[0m \x1b[37m{name}\x1b[0m ({m[1]:>2}/{m[2]:>2})", end="") print("") @cli.command() @click.pass_context def status(ctx: click.Context) -> None: asyncio.run(read_status()) if __name__ == "__main__": config = Config(os.path.join(SCRIPT_BIN, "config")) cli(obj={ Loading Loading
tpm2-algtest-collect +205 −8 Original line number Diff line number Diff line #!env python3 #!/bin/env python3.10 import asyncio import asyncssh import bson import click import logging import os import random import re import sys from datetime import datetime from pathlib import Path from typing import Any, Optional, Type, TypeVar from typing import Any, Optional, Tuple, Type, TypeVar T = TypeVar('T') SCRIPT_BIN = os.path.dirname(os.path.realpath(__file__)) REPO_PATH = '/var/lib/tpm2-algtest-nspawn/' UNIX_SOCK = os.path.join(SCRIPT_BIN, "_sock") DEPOSITORY = '/mnt/local/tpm/' class Utils: Loading Loading @@ -180,6 +185,38 @@ class CompositeReporterFactory(ReporterFactory): *list(map(lambda f: f.get_for(machine), self.factories))) class AlgtestStatus: def __init__(self): self._data = {} self._cond = asyncio.Condition() self._n = 0 async def _notify(self): async with self._cond: self._n += 1 self._cond.notify_all() async def add(self, machine: str) -> None: self._data[machine] = 'started' await self._notify() async def set(self, machine: str, status: str) -> None: self._data[machine] = status await self._notify() def machines(self) -> list[str]: return sorted(self._data.keys()) def data(self) -> dict[str, str]: return self._data async def data_wait(self, n: int = -1) -> Tuple[dict[str, str], int]: async with self._cond: while n == self._n: await self._cond.wait() return (self._data, self._n) class SSHReportErrSession(asyncssh.SSHClientSession[str]): def __init__(self, reporter: ReporterInterface): self.reporter = reporter Loading @@ -194,21 +231,27 @@ class Collector: @staticmethod async def _run_commands_with_connection( reporter: ReporterInterface, status: AlgtestStatus, machine: str, conn: asyncssh.SSHClientConnection, commands: list[str]) -> ReturnCode: for command in commands: for (index, command) in enumerate(commands): reporter.command(command) await status.set(machine, f"{index + 1}/{len(commands)}") chan, session = await conn.create_session( lambda: SSHReportErrSession(reporter), command) await chan.wait_closed() rc = ReturnCode(chan.get_returncode()) if not rc: reporter.failure(command, rc) await status.set(machine, f"failed ({index})") return rc await status.set(machine, "success") return ReturnCode.success() @staticmethod Loading @@ -221,27 +264,32 @@ class Collector: @staticmethod async def _run_commands(reporter: ReporterInterface, status: AlgtestStatus, machine: str, commands: list[str]) -> ReturnCode: try: await status.add(machine) reporter.status("Connecting") async with asyncssh.connect( machine, config=None, options=Collector._ssh_options() ) as conn: await status.set(machine, f"0/{len(commands)}") reporter.status("Connected") return await Collector._run_commands_with_connection( reporter, machine, conn, commands reporter, status, machine, conn, commands ) except Exception as ex: await status.set(machine, 'failed') reporter.ex(ex) return ReturnCode(255) @staticmethod def test() -> list[str]: return [ "echo Hello >&2; exit 1" "echo Hello >&2", "sleep $((5 + RANDOM % 10))", ] @staticmethod Loading @@ -249,6 +297,7 @@ class Collector: return [ f"test -d {REPO_PATH} || git clone {config['tpm2_tools_remote']} {REPO_PATH}", f"git -C {REPO_PATH} clean -xf && git -C {REPO_PATH} restore .", f"git -C {REPO_PATH} checkout main", f"git -C {REPO_PATH} pull --rebase", ] Loading @@ -270,6 +319,18 @@ class Collector: f"ps aux|fgrep nspawn >&2", ] @staticmethod async def _run_tasks(reporter_factory: ReporterFactory, machines: list[str], status: AlgtestStatus, server, commands: list[str]): await asyncio.gather( *[Collector._run_commands(reporter_factory.get_for(machine), status, machine, commands) for machine in machines] ) await server.stop() @staticmethod async def run(config: Config, machines: list[str], commands: list[str]) -> None: reporter_factory: Optional[ReporterFactory] = None Loading @@ -282,17 +343,75 @@ class Collector: else: reporter_factory = ReporterFactory(LogReporter) status = AlgtestStatus() server = UnixServer(UNIX_SOCK, status) await server.start() await asyncio.gather( *[Collector._run_commands(reporter_factory.get_for(machine), machine, commands) server.serve(), Collector._run_tasks(reporter_factory, machines, status, server, commands) ) @staticmethod async def _download_assets(machine: str) -> None: try: async with asyncssh.connect( machine, config=None, options=Collector._ssh_options()) as conn: path = os.path.join(REPO_PATH, '*/out-*.zip') await asyncssh.scp((conn, path), DEPOSITORY) except Exception as ex: print(f"{machine}: {ex}", file=sys.stderr) @staticmethod async def download(config: Config, machines: list[str]) -> None: await asyncio.gather( *[Collector._download_assets(machine) for machine in machines] ) class UnixServer: _status: Optional[AlgtestStatus] = None def __init__(self, path: str, status: AlgtestStatus): self.path = path UnixServer._status = status async def start(self): self.sock = await asyncio.start_unix_server(UnixServer._serve, path=self.path) logging.info(f"Started UNIX socket server: {self.path}") async def stop(self): self.sock.close() async def serve(self): try: await self.sock.serve_forever() except asyncio.exceptions.CancelledError: logging.info("Stopped UNIX socket server") @staticmethod async def _serve(rd: asyncio.StreamReader, wr: asyncio.StreamWriter) -> None: assert UnixServer._status is not None n = -1 try: while p := await UnixServer._status.data_wait(n): (data, n) = p wr.write(bson.dumps(data)) await wr.drain() except Exception as ex: print(ex) wr.close() def set_machines(machines: list[str]) -> list[str]: if not machines or (len(machines) == 1 and machines[0] == 'all'): return MACHINES if (len(machines) == 1 and machines[0] in ['nymfe', 'musa']): return list(filter(lambda host: host.startswith(machines[0]), MACHINES)) return list( map(lambda host: f"{host}" if "." in host else f"{host}.fi.muni.cz", machines) Loading @@ -309,10 +428,23 @@ def setup_logging() -> None: @click.group() @click.option('--debug', '--no-debug', default=False, is_flag=True) @click.option('--socket') @click.option('--repo-path') @click.pass_context def cli(ctx: dict[str, Any], debug: bool) -> None: def cli(ctx: dict[str, Any], debug: bool, socket: Optional[str], repo_path: Optional[str]) -> None: setup_logging() if socket is not None: if '/' not in socket: socket = os.path.join(SCRIPT_BIN, socket) global UNIX_SOCK UNIX_SOCK = socket if repo_path is not None: global REPO_PATH REPO_PATH = repo_path if debug: logging.getLogger().setLevel(logging.DEBUG) asyncssh.set_log_level(logging.DEBUG) Loading @@ -329,6 +461,14 @@ def test(ctx: click.Context, machines: list[str]) -> None: asyncio.run(Collector.run(ctx.obj['config'], machines, Collector.test())) @cli.command() @click.argument('machines', nargs=-1) @click.pass_context def download(ctx: click.Context, machines: list[str]) -> None: machines = set_machines(machines) asyncio.run(Collector.download(ctx.obj['config'], machines)) @cli.command() @click.argument('machines', nargs=-1) @click.pass_context Loading @@ -342,8 +482,14 @@ def update(ctx: click.Context, machines: list[str]) -> None: @click.pass_context def clean(ctx: click.Context, machines: list[str]) -> None: machines = set_machines(machines) # TODO: Investigating nymfe55 if "nymfe55.fi.muni.cz" in machines: machines.remove("nymfe55.fi.muni.cz") asyncio.run(Collector.run(ctx.obj['config'], machines, Collector.clean())) @cli.command() @click.argument('machines', nargs=-1) @click.pass_context Loading @@ -351,6 +497,7 @@ def generate(ctx: click.Context, machines: list[str]) -> None: machines = set_machines(machines) asyncio.run(Collector.run(ctx.obj['config'], machines, Collector.generate())) @cli.command() @click.argument('machines', nargs=-1) @click.pass_context Loading @@ -359,6 +506,56 @@ def kill(ctx: click.Context, machines: list[str]) -> None: asyncio.run(Collector.run(ctx.obj['config'], machines, Collector.kill())) async def read_status() -> None: (rd, wr) = await asyncio.open_unix_connection(UNIX_SOCK) while True: bs = await rd.read(8192) if len(bs) == 0: break raw = bson.loads(bs) data = {} done = 0 for (key, value) in raw.items(): data[(key.split('.'))[0]] = value if value in ['success', 'failed']: done += 1 max_key = max(map(lambda s: len(s), data.keys())) (width, _) = os.get_terminal_size() cols = max(1, width // (max_key + 12)) def chunks(lst, size): for i in range(0, len(lst), size): yield lst[i:i + size] click.clear() print(f"⌚{datetime.now().strftime('%F %T')} Status: {done} / {len(data)}") for chunk in chunks(list(data.keys()), cols): for machine in chunk: status = data[machine] name = f"{machine:{max_key}}" if status == 'success': print(f" \x1b[32m✔ {name}\x1b[0m ", end="") done += 1 elif status == 'failed': print(f" \x1b[91m✘ {name}\x1b[0m ", end="") done += 1 elif m := re.search(r'^(\d)/(\d)$', status): print(f" \x1b[93m*\x1b[0m \x1b[37m{name}\x1b[0m ({m[1]:>2}/{m[2]:>2})", end="") print("") @cli.command() @click.pass_context def status(ctx: click.Context) -> None: asyncio.run(read_status()) if __name__ == "__main__": config = Config(os.path.join(SCRIPT_BIN, "config")) cli(obj={ Loading