Commit d998cfeb authored by Vít Starý Novotný
Browse files

Add --gpus command-line option

parent 6f9703ee
Loading
Loading
Loading
Loading
Loading
+1 −1
Original line number Original line Diff line number Diff line
@@ -19,7 +19,7 @@ to directory output/:
To achieve the best results, you should also enable image super-resolution
To achieve the best results, you should also enable image super-resolution
(requires GPU) and Google Vision AI (requires paid account):
(requires GPU) and Google Vision AI (requires paid account):


    ahisto-ocr --super-resolution --google-vision-ai --google-api-key key_file input/ output/
    ahisto-ocr --super-resolution --google-vision-ai --gpus 10,11,12 --google-api-key key_file input/ output/


Here is example output of the tool for two images:
Here is example output of the tool for two images:


+8 −4
Original line number Original line Diff line number Diff line
@@ -37,8 +37,12 @@ LOGGER = getLogger(__name__)
              default=False,
              default=False,
              is_flag=True,
              is_flag=True,
              required=False)
              required=False)
@click.option('--gpus',
              help='Comma-separated PCI BUS IDs of NVIDIA GPUs that will be used for image super-resolution',
              default='all',
              required=False)
def main(input_dir: str, output_dir: str, google_vision_ai: bool,
def main(input_dir: str, output_dir: str, google_vision_ai: bool,
         google_api_key: Optional[str], super_resolution: bool) -> None:
         google_api_key: Optional[str], super_resolution: bool, gpus: str) -> None:


    if google_vision_ai and google_api_key is None:
    if google_vision_ai and google_api_key is None:
        raise ValueError('Cannot use Google Vision AI without an API key.')
        raise ValueError('Cannot use Google Vision AI without an API key.')
@@ -67,11 +71,11 @@ def main(input_dir: str, output_dir: str, google_vision_ai: bool,


    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')


    run_ocr(input_dir, input_images, output_dir, google_vision_ai, google_api_key, super_resolution)
    run_ocr(input_dir, input_images, output_dir, google_vision_ai, google_api_key, super_resolution, gpus)




def run_ocr(input_dir: Path, input_images: List[Path], output_dir: Path, google_vision_ai: bool,
def run_ocr(input_dir: Path, input_images: List[Path], output_dir: Path, google_vision_ai: bool,
            google_api_key: Optional[Path], super_resolution: bool) -> None:
            google_api_key: Optional[Path], super_resolution: bool, gpus: str) -> None:
    client = docker.from_env()
    client = docker.from_env()
    with create_temporary_docker_volume(client) as input_volume, \
    with create_temporary_docker_volume(client) as input_volume, \
            create_temporary_docker_volume(client) as postprocessing_volume, \
            create_temporary_docker_volume(client) as postprocessing_volume, \
@@ -82,7 +86,7 @@ def run_ocr(input_dir: Path, input_images: List[Path], output_dir: Path, google_
        copy_input_to(client, input_volume, input_dir, input_images, google_api_key)
        copy_input_to(client, input_volume, input_dir, input_images, google_api_key)


        if super_resolution:
        if super_resolution:
            apply_super_resolution(client, input_volume, postprocessing_volume)
            apply_super_resolution(client, input_volume, postprocessing_volume, gpus)
        else:
        else:
            LOGGER.info('Skipping image super-resolution')
            LOGGER.info('Skipping image super-resolution')
            postprocessing_volume = input_volume
            postprocessing_volume = input_volume
+0 −1
Original line number Original line Diff line number Diff line
[preprocessing]
[preprocessing]
nvidia_visible_devices = all
model = anime_style_art_rgb
model = anime_style_art_rgb
noise_level = 3
noise_level = 3


+2 −2
Original line number Original line Diff line number Diff line
@@ -14,7 +14,7 @@ LOGGER = getLogger(__name__)
CONFIG = _CONFIG['preprocessing']
CONFIG = _CONFIG['preprocessing']




def apply_super_resolution(client, input_volume, postprocessing_volume) -> None:
def apply_super_resolution(client, input_volume, postprocessing_volume, gpus: str) -> None:
    LOGGER.info('Pre-processing the input images using super-resolution')
    LOGGER.info('Pre-processing the input images using super-resolution')


    volumes = {
    volumes = {
@@ -52,7 +52,7 @@ def apply_super_resolution(client, input_volume, postprocessing_volume) -> None:
        '-scale', '2', '-noise_level', str(CONFIG.getint('noise_level')), '-l', '/input/list.txt',
        '-scale', '2', '-noise_level', str(CONFIG.getint('noise_level')), '-l', '/input/list.txt',
    ]
    ]
    run_docker_container(client, 'ahisto/waifu2x', runtime='nvidia',
    run_docker_container(client, 'ahisto/waifu2x', runtime='nvidia',
                         environment={'NVIDIA_VISIBLE_DEVICES': CONFIG['nvidia_visible_devices']},
                         environment={'NVIDIA_VISIBLE_DEVICES': gpus},
                         command=command, volumes=volumes)
                         command=command, volumes=volumes)


    with create_temporary_docker_container(client, 'ahisto/empty', command='cmd',
    with create_temporary_docker_container(client, 'ahisto/empty', command='cmd',