Verified Commit c526fadf authored by Roman Lacko's avatar Roman Lacko
Browse files

tpm2-algtest-collect: Finalized data collection

parent eaf85994
Loading
Loading
Loading
Loading
+205 −8
Original line number Diff line number Diff line
#!env python3
#!/bin/env python3.10

import asyncio
import asyncssh
import bson
import click
import logging
import os
import random
import re
import sys

from datetime import datetime
from pathlib import Path
from typing import Any, Optional, Type, TypeVar
from typing import Any, Optional, Tuple, Type, TypeVar

T = TypeVar('T')
SCRIPT_BIN = os.path.dirname(os.path.realpath(__file__))
REPO_PATH = '/var/lib/tpm2-algtest-nspawn/'
UNIX_SOCK = os.path.join(SCRIPT_BIN, "_sock")
DEPOSITORY = '/mnt/local/tpm/'


class Utils:
@@ -180,6 +185,38 @@ class CompositeReporterFactory(ReporterFactory):
                *list(map(lambda f: f.get_for(machine), self.factories)))


class AlgtestStatus:
    def __init__(self):
        self._data = {}
        self._cond = asyncio.Condition()
        self._n = 0

    async def _notify(self):
        async with self._cond:
            self._n += 1
            self._cond.notify_all()

    async def add(self, machine: str) -> None:
        self._data[machine] = 'started'
        await self._notify()

    async def set(self, machine: str, status: str) -> None:
        self._data[machine] = status
        await self._notify()

    def machines(self) -> list[str]:
        return sorted(self._data.keys())

    def data(self) -> dict[str, str]:
        return self._data

    async def data_wait(self, n: int = -1) -> Tuple[dict[str, str], int]:
        async with self._cond:
            while n == self._n:
                await self._cond.wait()
            return (self._data, self._n)


class SSHReportErrSession(asyncssh.SSHClientSession[str]):
    def __init__(self, reporter: ReporterInterface):
        self.reporter = reporter
@@ -194,21 +231,27 @@ class Collector:
    @staticmethod
    async def _run_commands_with_connection(
            reporter: ReporterInterface,
            status: AlgtestStatus,
            machine: str,
            conn: asyncssh.SSHClientConnection,
            commands: list[str]) -> ReturnCode:

        for command in commands:
        for (index, command) in enumerate(commands):
            reporter.command(command)
            await status.set(machine, f"{index + 1}/{len(commands)}")

            chan, session = await conn.create_session(
                    lambda: SSHReportErrSession(reporter), command)

            await chan.wait_closed()
            rc = ReturnCode(chan.get_returncode())

            if not rc:
                reporter.failure(command, rc)
                await status.set(machine, f"failed ({index})")
                return rc

        await status.set(machine, "success")
        return ReturnCode.success()

    @staticmethod
@@ -221,27 +264,32 @@ class Collector:

    @staticmethod
    async def _run_commands(reporter: ReporterInterface,
                            status: AlgtestStatus,
                            machine: str,
                            commands: list[str]) -> ReturnCode:
        try:
            await status.add(machine)
            reporter.status("Connecting")
            async with asyncssh.connect(
                        machine, config=None, options=Collector._ssh_options()
                    ) as conn:

                await status.set(machine, f"0/{len(commands)}")
                reporter.status("Connected")
                return await Collector._run_commands_with_connection(
                        reporter, machine, conn, commands
                        reporter, status, machine, conn, commands
                )

        except Exception as ex:
            await status.set(machine, 'failed')
            reporter.ex(ex)
            return ReturnCode(255)

    @staticmethod
    def test() -> list[str]:
        return [
            "echo Hello >&2; exit 1"
            "echo Hello >&2",
            "sleep $((5 + RANDOM % 10))",
        ]

    @staticmethod
@@ -249,6 +297,7 @@ class Collector:
        return [
            f"test -d {REPO_PATH} || git clone {config['tpm2_tools_remote']} {REPO_PATH}",
            f"git -C {REPO_PATH} clean -xf && git -C {REPO_PATH} restore .",
            f"git -C {REPO_PATH} checkout main",
            f"git -C {REPO_PATH} pull --rebase",
        ]

@@ -270,6 +319,18 @@ class Collector:
            f"ps aux|fgrep nspawn >&2",
        ]

    @staticmethod
    async def _run_tasks(reporter_factory: ReporterFactory, machines: list[str],
                         status: AlgtestStatus, server, commands: list[str]):

        await asyncio.gather(
            *[Collector._run_commands(reporter_factory.get_for(machine),
                                      status, machine, commands)
                for machine in machines]
        )

        await server.stop()

    @staticmethod
    async def run(config: Config, machines: list[str], commands: list[str]) -> None:
        reporter_factory: Optional[ReporterFactory] = None
@@ -282,17 +343,75 @@ class Collector:
        else:
            reporter_factory = ReporterFactory(LogReporter)

        status = AlgtestStatus()
        server = UnixServer(UNIX_SOCK, status)
        await server.start()

        await asyncio.gather(
            *[Collector._run_commands(reporter_factory.get_for(machine),
                                      machine, commands)
            server.serve(),
            Collector._run_tasks(reporter_factory, machines, status, server, commands)
        )

    @staticmethod
    async def _download_assets(machine: str) -> None:
        try:
            async with asyncssh.connect(
                    machine, config=None, options=Collector._ssh_options()) as conn:
                path = os.path.join(REPO_PATH, '*/out-*.zip')
                await asyncssh.scp((conn, path), DEPOSITORY)
        except Exception as ex:
            print(f"{machine}: {ex}", file=sys.stderr)


    @staticmethod
    async def download(config: Config, machines: list[str]) -> None:
        await asyncio.gather(
            *[Collector._download_assets(machine)
                for machine in machines]
        )

class UnixServer:
    _status: Optional[AlgtestStatus] = None

    def __init__(self, path: str, status: AlgtestStatus):
        self.path = path
        UnixServer._status = status

    async def start(self):
        self.sock = await asyncio.start_unix_server(UnixServer._serve, path=self.path)
        logging.info(f"Started UNIX socket server: {self.path}")

    async def stop(self):
        self.sock.close()

    async def serve(self):
        try:
            await self.sock.serve_forever()
        except asyncio.exceptions.CancelledError:
            logging.info("Stopped UNIX socket server")

    @staticmethod
    async def _serve(rd: asyncio.StreamReader, wr: asyncio.StreamWriter) -> None:
        assert UnixServer._status is not None

        n = -1
        try:
            while p := await UnixServer._status.data_wait(n):
                (data, n) = p
                wr.write(bson.dumps(data))
                await wr.drain()
        except Exception as ex:
            print(ex)
            wr.close()


def set_machines(machines: list[str]) -> list[str]:
    if not machines or (len(machines) == 1 and machines[0] == 'all'):
        return MACHINES

    if (len(machines) == 1 and machines[0] in ['nymfe', 'musa']):
        return list(filter(lambda host: host.startswith(machines[0]), MACHINES))

    return list(
            map(lambda host: f"{host}" if "." in host else f"{host}.fi.muni.cz",
                machines)
@@ -309,10 +428,23 @@ def setup_logging() -> None:

@click.group()
@click.option('--debug', '--no-debug', default=False, is_flag=True)
@click.option('--socket')
@click.option('--repo-path')
@click.pass_context
def cli(ctx: dict[str, Any], debug: bool) -> None:
def cli(ctx: dict[str, Any], debug: bool, socket: Optional[str],
        repo_path: Optional[str]) -> None:
    setup_logging()

    if socket is not None:
        if '/' not in socket:
            socket = os.path.join(SCRIPT_BIN, socket)
        global UNIX_SOCK
        UNIX_SOCK = socket

    if repo_path is not None:
        global REPO_PATH
        REPO_PATH = repo_path

    if debug:
        logging.getLogger().setLevel(logging.DEBUG)
        asyncssh.set_log_level(logging.DEBUG)
@@ -329,6 +461,14 @@ def test(ctx: click.Context, machines: list[str]) -> None:
    asyncio.run(Collector.run(ctx.obj['config'], machines, Collector.test()))


@cli.command()
@click.argument('machines', nargs=-1)
@click.pass_context
def download(ctx: click.Context, machines: list[str]) -> None:
    machines = set_machines(machines)
    asyncio.run(Collector.download(ctx.obj['config'], machines))


@cli.command()
@click.argument('machines', nargs=-1)
@click.pass_context
@@ -342,8 +482,14 @@ def update(ctx: click.Context, machines: list[str]) -> None:
@click.pass_context
def clean(ctx: click.Context, machines: list[str]) -> None:
    machines = set_machines(machines)

    # TODO: Investigating nymfe55
    if "nymfe55.fi.muni.cz" in machines:
        machines.remove("nymfe55.fi.muni.cz")

    asyncio.run(Collector.run(ctx.obj['config'], machines, Collector.clean()))


@cli.command()
@click.argument('machines', nargs=-1)
@click.pass_context
@@ -351,6 +497,7 @@ def generate(ctx: click.Context, machines: list[str]) -> None:
    machines = set_machines(machines)
    asyncio.run(Collector.run(ctx.obj['config'], machines, Collector.generate()))


@cli.command()
@click.argument('machines', nargs=-1)
@click.pass_context
@@ -359,6 +506,56 @@ def kill(ctx: click.Context, machines: list[str]) -> None:
    asyncio.run(Collector.run(ctx.obj['config'], machines, Collector.kill()))


async def read_status() -> None:
    (rd, wr) = await asyncio.open_unix_connection(UNIX_SOCK)

    while True:
        bs = await rd.read(8192)
        if len(bs) == 0:
            break
        raw = bson.loads(bs)
        data = {}

        done = 0
        for (key, value) in raw.items():
            data[(key.split('.'))[0]] = value

            if value in ['success', 'failed']:
                done += 1

        max_key = max(map(lambda s: len(s), data.keys()))
        (width, _) = os.get_terminal_size()

        cols = max(1, width // (max_key + 12))

        def chunks(lst, size):
            for i in range(0, len(lst), size):
                yield lst[i:i + size]

        click.clear()
        print(f"{datetime.now().strftime('%F %T')}  Status: {done} / {len(data)}")
        for chunk in chunks(list(data.keys()), cols):
            for machine in chunk:
                status = data[machine]

                name = f"{machine:{max_key}}"
                if status == 'success':
                    print(f" \x1b[32m✔ {name}\x1b[0m        ", end="")
                    done += 1
                elif status == 'failed':
                    print(f" \x1b[91m✘ {name}\x1b[0m        ", end="")
                    done += 1
                elif m := re.search(r'^(\d)/(\d)$', status):
                    print(f" \x1b[93m*\x1b[0m \x1b[37m{name}\x1b[0m ({m[1]:>2}/{m[2]:>2})", end="")
            print("")


@cli.command()
@click.pass_context
def status(ctx: click.Context) -> None:
    asyncio.run(read_status())


if __name__ == "__main__":
    config = Config(os.path.join(SCRIPT_BIN, "config"))
    cli(obj={