diff --git a/cyberdrop_dl/cli/__init__.py b/cyberdrop_dl/cli/__init__.py index 1eeb6f9e3..d7d1af189 100644 --- a/cyberdrop_dl/cli/__init__.py +++ b/cyberdrop_dl/cli/__init__.py @@ -1,304 +1,48 @@ +from __future__ import annotations + import dataclasses -import datetime import sys -import time -import warnings -from argparse import SUPPRESS, ArgumentParser, BooleanOptionalAction, RawDescriptionHelpFormatter -from argparse import _ArgumentGroup as ArgGroup -from collections.abc import Iterable, Sequence -from enum import StrEnum, auto -from pathlib import Path +from argparse import SUPPRESS, ArgumentParser, RawDescriptionHelpFormatter from shutil import get_terminal_size -from typing import Annotated, Any, Literal, NoReturn, Self +from typing import TYPE_CHECKING, Any, Final, NoReturn -from pydantic import BaseModel, Field, ValidationError, computed_field, field_validator, model_validator +from pydantic import BaseModel, ValidationError from cyberdrop_dl import __version__, env +from cyberdrop_dl.cli import arguments +from cyberdrop_dl.cli.model import CLIargs, ParsedArgs from cyberdrop_dl.config import ConfigSettings, GlobalSettings -from cyberdrop_dl.models import AliasModel -from cyberdrop_dl.models.types import HttpURL -from cyberdrop_dl.utils.yaml import handle_validation_error - - -class UIOptions(StrEnum): - DISABLED = auto() - ACTIVITY = auto() - SIMPLE = auto() - FULLSCREEN = auto() - - -warnings.simplefilter("always", DeprecationWarning) -WARNING_TIMEOUT = 5 # seconds - -def _check_mutually_exclusive(group: Iterable[Any], msg: str) -> None: - if sum(1 for value in group if value) >= 2: - raise ValueError(msg) +if TYPE_CHECKING: + from argparse import _ArgumentGroup as ArgGroup # pyright: ignore[reportPrivateUsage] + from collections.abc import Sequence def is_terminal_in_portrait() -> bool: """Check if CDL is being run in portrait mode based on a few conditions.""" - # Return True if running in portrait mode, False otherwise (landscape mode) - - def check_terminal_size() -> bool: - terminal_size = get_terminal_size() - width, height = terminal_size.columns, terminal_size.lines - aspect_ratio = width / height - - # High aspect ratios are likely to be in landscape mode - if aspect_ratio >= 3.2: - return False - - # Check for mobile device in portrait mode - if (aspect_ratio < 1.5 and height >= 40) or (width <= 85 and aspect_ratio < 2.3): - return True - - # Assume landscape mode for other cases - return False if env.PORTRAIT_MODE: return True - return check_terminal_size() - - -_NOT_SET: Any = object() - - -@dataclasses.dataclass(slots=True, frozen=True, kw_only=True) -class CommandOptions: - nargs: int | str | None = _NOT_SET - const: Any = _NOT_SET - - def as_dict(self) -> dict[str, Any]: - return {k: v for k, v in dataclasses.asdict(self).items() if v is not _NOT_SET} - - -class CommandLineOnlyArgs(BaseModel): - links: list[HttpURL] = Field( - default=[], - description="link(s) to content to download (passing multiple links is supported)", - ) - appdata_folder: Path | None = Field( - default=None, - description="AppData folder path", - ) - completed_after: datetime.date | None = Field( - default=None, - description="only retry downloads that were completed on or after this date", - ) - completed_before: datetime.date | None = Field( - default=None, - description="only retry downloads that were completed on or before this date", - ) - - config_file: Path | None = Field( - default=None, - description="path to the CDL settings.yaml file to load", - ) - - download: bool = Field( - default=False, - description="skips UI, start download immediately", - ) - download_tiktok_audios: bool = Field( - default=False, - description="download TikTok audios from posts and save them as separate files", - ) - download_tiktok_src_quality_videos: bool = Field( - default=False, - description="download TikTok videos in source quality", - ) - impersonate: Annotated[ - Literal[ - "chrome", - "edge", - "safari", - "safari_ios", - "chrome_android", - "firefox", - ] - | bool - | None, - CommandOptions(nargs="?", const=True), - ] = Field( - default=None, - description="Use this target as impersonation for all scrape requests", - ) - max_items_retry: int = Field( - default=0, - description="max number of links to retry", - ) - portrait: bool = Field( - default=is_terminal_in_portrait(), - description="force CDL to run with a vertical layout", - ) - print_stats: bool = Field( - default=True, - description="show stats report at the end of a run", - ) - retry_all: bool = Field( - default=False, - description="retry all downloads", - ) - retry_failed: bool = Field( - default=False, - description="retry failed downloads", - ) - retry_maintenance: bool = Field( - default=False, - description="retry download of maintenance files (bunkr). Requires files to be hashed", - ) - show_supported_sites: bool = Field( - default=False, - description="shows a list of supported sites and exits", - ) - ui: UIOptions = Field( - default=UIOptions.FULLSCREEN, - description="DISABLED, ACTIVITY, SIMPLE or FULLSCREEN", - ) - - @property - def retry_any(self) -> bool: - return any((self.retry_all, self.retry_failed, self.retry_maintenance)) - - @property - def fullscreen_ui(self) -> bool: - return self.ui == UIOptions.FULLSCREEN - - @computed_field - def __computed__(self) -> dict[str, bool]: - return {"retry_any": self.retry_any, "fullscreen_ui": self.fullscreen_ui} - - @model_validator(mode="after") - def mutually_exclusive(self) -> Self: - group1 = [self.links, self.retry_all, self.retry_failed, self.retry_maintenance] - msg1 = "`--links`, '--retry-all', '--retry-maintenace' and '--retry-failed' are mutually exclusive" - _check_mutually_exclusive(group1, msg1) - return self - - @field_validator("ui", mode="before") - @classmethod - def lower(cls, value: str) -> str: - return value.lower() - - -class DeprecatedArgs(BaseModel): ... - - -class ParsedArgs(AliasModel): - cli_only_args: CommandLineOnlyArgs = CommandLineOnlyArgs() - config_settings: ConfigSettings = ConfigSettings() - deprecated_args: DeprecatedArgs = DeprecatedArgs() - global_settings: GlobalSettings = GlobalSettings() - - def model_post_init(self, *_) -> None: - exit_on_warning = False - - if self.cli_only_args.retry_all or self.cli_only_args.retry_maintenance: - self.config_settings.runtime_options.ignore_history = True - - if ( - not self.cli_only_args.fullscreen_ui - or self.cli_only_args.retry_any - or self.cli_only_args.config_file - or self.config_settings.sorting.sort_downloads - ): - self.cli_only_args.download = True - - if warnings_to_emit := self.prepare_warnings(): - for msg in warnings_to_emit: - warnings.warn(msg, DeprecationWarning, stacklevel=10) - if exit_on_warning: - sys.exit(1) - - time.sleep(WARNING_TIMEOUT) - - def prepare_warnings(self) -> set[str]: - warnings_to_emit: set[str] = set() - - def add_warning_msg_from(field_name: str) -> None: - if not field_name: - return - info = DeprecatedArgs.model_fields[field_name].deprecated - warnings_to_emit.add(str(info)) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - pass - - return warnings_to_emit - - -def _add_args_from_model( - parser: ArgumentParser | ArgGroup, - model: type[BaseModel], - *, - cli_args: bool = False, - deprecated: bool = False, - prefix: str = "", -) -> None: - for name, field in model.model_fields.items(): - full_name = prefix + name - cli_name = full_name.replace("_", "-") - arg_type = type(field.default) - - if issubclass(arg_type, BaseModel): - _add_args_from_model(parser, arg_type, cli_args=cli_args, deprecated=deprecated, prefix=f"{cli_name}.") - continue - - if arg_type not in (list, set, bool): - arg_type = str - - help_text = field.description or "" - default = field.default if cli_args else SUPPRESS - default_options: dict[str, Any] = {"default": default, "dest": full_name, "help": help_text} - for meta in field.metadata: - if isinstance(meta, CommandOptions): - default_options |= meta.as_dict() - break - - name_or_flags = [f"--{cli_name}"] - alias = field.alias or field.validation_alias or field.serialization_alias - if alias and len(str(alias)) == 1: - name_or_flags.insert(0, f"-{alias}") - if arg_type is bool: - action = BooleanOptionalAction - default_options.pop("default") - if cli_args and not (cli_name == "portrait" and env.RUNNING_IN_TERMUX): - action = "store_false" if default else "store_true" - if deprecated: - default_options = default_options | {"default": SUPPRESS} - parser.add_argument(*name_or_flags, action=action, **default_options) - continue - - if cli_name == "links": - _ = default_options.pop("dest") - _ = parser.add_argument(cli_name, metavar="LINK(S)", nargs="*", action="extend", **default_options) - continue - - if arg_type in (list, set): - _ = parser.add_argument(*name_or_flags, nargs="*", action="extend", **default_options) - continue - - _ = parser.add_argument(*name_or_flags, type=arg_type, **default_options) + terminal_size = get_terminal_size() + width, height = terminal_size.columns, terminal_size.lines + aspect_ratio = width / height + # High aspect ratios are likely to be in landscape mode + if aspect_ratio >= 3.2: + return False -def _create_groups_from_nested_models(parser: ArgumentParser, model: type[BaseModel]) -> list[ArgGroup]: - groups: list[ArgGroup] = [] - for name, field in model.model_fields.items(): - submodel = field.annotation - assert submodel and issubclass(submodel, BaseModel) - submodel_group = parser.add_argument_group(name) - _add_args_from_model(submodel_group, submodel) - groups.append(submodel_group) + # Check for mobile device in portrait mode + if (aspect_ratio < 1.5 and height >= 40) or (width <= 85 and aspect_ratio < 2.3): + return True - return groups + # Assume landscape mode for other cases + return False class CustomHelpFormatter(RawDescriptionHelpFormatter): - MAX_HELP_POS = 80 - INDENT_INCREMENT = 2 + MAX_HELP_POS: Final = 80 + INDENT_INCREMENT: Final = 2 def __init__(self, prog: str, width: int | None = None) -> None: super().__init__(prog, self.INDENT_INCREMENT, self.MAX_HELP_POS, width) @@ -309,70 +53,71 @@ def _get_help_string(self, action) -> str | None: return action.help -USING_DEPRECATED_ARGS: bool = bool(DeprecatedArgs.model_fields) +@dataclasses.dataclass(slots=True) +class CLIParser: + parser: ArgumentParser + groups: dict[str, list[ArgGroup]] + + def parse_args(self, args: Sequence[str] | None = None) -> dict[str, dict[str, Any]]: + return self._unflatten(self._parse_args(args)) + + def _parse_args(self, args: Sequence[str] | None = None) -> dict[str, Any]: + return dict(sorted(vars(self.parser.parse_intermixed_args(args)).items())) + + def _unflatten(self, namespace: dict[str, Any]) -> dict[str, dict[str, Any]]: + parsed_args: dict[str, dict[str, Any]] = {} + for name, groups in self.groups.items(): + parsed_args[name] = {} + for group in groups: + group_dict = {arg.dest: v for arg in group._group_actions if (v := namespace.get(arg.dest)) is not None} + if group_dict: + assert group.title + parsed_args[name][group.title] = _unflatten_nested_args(group_dict) -def make_parser() -> tuple[ArgumentParser, dict[str, list[ArgGroup]]]: + parsed_args["cli_only_args"] = parsed_args["cli_only_args"]["CLI-only options"] + return parsed_args + + +def make_parser() -> CLIParser: + kwargs: dict[str, Any] = {"color": True} if sys.version_info > (3, 14) else {} parser = ArgumentParser( description="Bulk asynchronous downloader for multiple file hosts", usage="cyberdrop-dl [OPTIONS] URL [URL...]", + allow_abbrev=False, formatter_class=CustomHelpFormatter, + **kwargs, ) _ = parser.add_argument("-V", "--version", action="version", version=f"%(prog)s {__version__}") cli_only = parser.add_argument_group("CLI-only options") - _add_args_from_model(cli_only, CommandLineOnlyArgs, cli_args=True) + _add_args_from_model(cli_only, CLIargs) - groups_mapping = { + groups = { "config_settings": _create_groups_from_nested_models(parser, ConfigSettings), "global_settings": _create_groups_from_nested_models(parser, GlobalSettings), "cli_only_args": [cli_only], } - if USING_DEPRECATED_ARGS: - deprecated = parser.add_argument_group("deprecated") - _add_args_from_model(deprecated, DeprecatedArgs, cli_args=True, deprecated=True) - groups_mapping["deprecated_args"] = [deprecated] - - return parser, groups_mapping - - -def get_parsed_args_dict(args: Sequence[str] | None = None) -> dict[str, dict[str, Any]]: - parser, groups_mapping = make_parser() - namespace = parser.parse_intermixed_args(args) - parsed_args: dict[str, dict[str, Any]] = {} - for name, groups in groups_mapping.items(): - parsed_args[name] = {} - for group in groups: - group_dict = { - arg.dest: getattr(namespace, arg.dest) - for arg in group._group_actions - if getattr(namespace, arg.dest, None) is not None - } - if group_dict: - assert group.title - parsed_args[name][group.title] = parse_nested_values(group_dict) - - if USING_DEPRECATED_ARGS: - parsed_args["deprecated_args"] = parsed_args["deprecated_args"].get("deprecated") or {} - parsed_args["cli_only_args"] = parsed_args["cli_only_args"]["CLI-only options"] - return parsed_args + return CLIParser(parser, groups) def parse_args(args: Sequence[str] | None = None) -> ParsedArgs: """Parses the command line arguments passed into the program.""" - parsed_args_dict = get_parsed_args_dict(args) + from cyberdrop_dl.utils.yaml import handle_validation_error + + parsed_args = make_parser().parse_args(args) try: - parsed_args_model = ParsedArgs.model_validate(parsed_args_dict) + model = ParsedArgs.model_validate(parsed_args, extra="forbid") except ValidationError as e: handle_validation_error(e, title="CLI arguments") sys.exit(1) - if parsed_args_model.cli_only_args.show_supported_sites: + if model.cli_only_args.show_supported_sites: show_supported_sites() - return parsed_args_model + return model def show_supported_sites() -> NoReturn: @@ -385,10 +130,10 @@ def show_supported_sites() -> NoReturn: sys.exit(0) -def parse_nested_values(data_list: dict[str, Any]) -> dict[str, Any]: +def _unflatten_nested_args(data: dict[str, Any]) -> dict[str, Any]: result: dict[str, Any] = {} - for command_name, value in data_list.items(): + for command_name, value in data.items(): inner_names = command_name.split(".") current_level = result for index, key in enumerate(inner_names): @@ -399,3 +144,28 @@ def parse_nested_values(data_list: dict[str, Any]) -> dict[str, Any]: else: current_level[key] = value return result + + +def _add_args_from_model(parser: ArgumentParser | ArgGroup, model: type[BaseModel]) -> None: + cli_args = model is CLIargs + + for arg in arguments.parse(model): + options = arg.compose_options() + + if cli_args and arg.arg_type is bool and not (arg.cli_name == "portrait" and env.RUNNING_IN_TERMUX): + default = arg.default if cli_args else SUPPRESS + options["action"] = "store_false" if default else "store_true" + + _ = parser.add_argument(*arg.name_or_flags, **options) + + +def _create_groups_from_nested_models(parser: ArgumentParser, model: type[BaseModel]) -> list[ArgGroup]: + groups: list[ArgGroup] = [] + for name, field in model.model_fields.items(): + submodel = field.annotation + assert submodel and issubclass(submodel, BaseModel) + submodel_group = parser.add_argument_group(name) + _add_args_from_model(submodel_group, submodel) + groups.append(submodel_group) + + return groups diff --git a/cyberdrop_dl/cli/arguments.py b/cyberdrop_dl/cli/arguments.py new file mode 100644 index 000000000..85536a4ee --- /dev/null +++ b/cyberdrop_dl/cli/arguments.py @@ -0,0 +1,122 @@ +import dataclasses +from argparse import BooleanOptionalAction +from collections.abc import Generator, Iterable +from typing import Any, Literal, TypedDict + +from pydantic import BaseModel + +_NOT_SET: Any = object() + + +class _ArgumentParams(TypedDict, total=False): + action: str + nargs: int | str | None + const: Any + default: Any + choices: Iterable[Any] | None + required: bool + help: str | None + metavar: str | tuple[str, ...] | None + dest: str | None + + +@dataclasses.dataclass(slots=True, frozen=True, kw_only=True) +class ArgumentParams: + positional_only: bool = dataclasses.field(default=False, metadata={"exclude": True}) + nargs: Literal["?", "*", "+"] | None = _NOT_SET + const: Any = _NOT_SET + dest: str = _NOT_SET + choices: Iterable[Any] | None = _NOT_SET + metavar: str | tuple[str, ...] | None = _NOT_SET + + def as_dict(self) -> _ArgumentParams: + return {name: v for name in _params if (v := getattr(self, name)) is not _NOT_SET} # pyright: ignore[reportReturnType] + + +_params = tuple(f.name for f in dataclasses.fields(ArgumentParams) if not f.metadata.get("exclude")) + + +@dataclasses.dataclass(slots=True, kw_only=True) +class Argument: + name_or_flags: list[str] = dataclasses.field(init=False) + python_name: str + cli_name: str = dataclasses.field(init=False) + aliases: tuple[str, ...] + required: bool + default: Any + annotation: Any + help: str | None + metadata: list[Any] + positional_only: bool = dataclasses.field(init=False) + arg_type: type = dataclasses.field(init=False) + + def __post_init__(self) -> None: + self.cli_name = self.python_name.replace("_", "-") + self.arg_type = type(self.default) + + if self.arg_type not in (list, set, bool): + self.arg_type = str + + self.positional_only = override.positional_only if (override := self._overrides()) else False + cli_command = f"{'' if self.positional_only else '--'}{self.cli_name}" + self.name_or_flags = [cli_command] + + for alias in self.aliases: + if alias and len(alias) == 1: + self.name_or_flags.insert(0, f"-{alias}") + else: + self.name_or_flags.append(alias) + + def compose_options(self) -> _ArgumentParams: + options = self._options() + if override := self._overrides(): + return options | override.as_dict() + + return options + + def _overrides(self) -> ArgumentParams | None: + for meta in self.metadata: + if isinstance(meta, ArgumentParams): + return meta + + def _options(self) -> _ArgumentParams: + options = dict( # noqa: C408 + default=self.default, + help=self.help, + action="store", + ) + if not self.positional_only: + options["dest"] = self.python_name + + if self.arg_type is bool: + options["action"] = BooleanOptionalAction + + elif self.arg_type in (list, set): + options.update(nargs="*", action="extend") + + else: + options["type"] = self.arg_type + + return options # pyright: ignore[reportReturnType] + + +def parse(model: type[BaseModel]) -> Generator[Argument]: + for python_name, field in model.model_fields.items(): + aliases = filter( + None, + ( + field.alias, + field.validation_alias, + field.serialization_alias, + ), + ) + + yield Argument( + python_name=python_name, + aliases=tuple(map(str, aliases)), + annotation=field.annotation, + default=field.default, + required=field.is_required(), + metadata=field.metadata, + help=field.description or None, + ) diff --git a/cyberdrop_dl/cli/model.py b/cyberdrop_dl/cli/model.py new file mode 100644 index 000000000..59a1ad8b5 --- /dev/null +++ b/cyberdrop_dl/cli/model.py @@ -0,0 +1,153 @@ +import datetime +from collections.abc import Iterable +from enum import StrEnum, auto +from pathlib import Path +from typing import Annotated, Any, Literal, Self + +from pydantic import BaseModel, Field, computed_field, field_validator, model_validator + +from cyberdrop_dl.cli.arguments import ArgumentParams +from cyberdrop_dl.config import ConfigSettings, GlobalSettings +from cyberdrop_dl.models.types import HttpURL + + +class UIOptions(StrEnum): + DISABLED = auto() + ACTIVITY = auto() + SIMPLE = auto() + FULLSCREEN = auto() + + +class CLIargs(BaseModel): + links: Annotated[ + list[HttpURL], + ArgumentParams(positional_only=True, metavar="LINK(s)"), + ] = Field( + default=[], + description="link(s) to content to download (passing multiple links is supported)", + ) + appdata_folder: Path | None = Field( + default=None, + description="AppData folder path", + ) + completed_after: datetime.date | None = Field( + default=None, + description="only retry downloads that were completed on or after this date", + ) + completed_before: datetime.date | None = Field( + default=None, + description="only retry downloads that were completed on or before this date", + ) + + config_file: Path | None = Field( + default=None, + description="path to the CDL settings.yaml file to load", + ) + + download: bool = Field( + default=False, + description="skips UI, start download immediately", + ) + download_tiktok_audios: bool = Field( + default=False, + description="download TikTok audios from posts and save them as separate files", + ) + download_tiktok_src_quality_videos: bool = Field( + default=False, + description="download TikTok videos in source quality", + ) + impersonate: Annotated[ + Literal[ + "chrome", + "edge", + "safari", + "safari_ios", + "chrome_android", + "firefox", + ] + | bool + | None, + ArgumentParams(nargs="?", const=True), + ] = Field( + default=None, + description="Use this target as impersonation for all scrape requests", + ) + max_items_retry: int = Field( + default=0, + description="max number of links to retry", + ) + portrait: bool = Field( + default=False, + description="force CDL to run with a vertical layout", + ) + print_stats: bool = Field( + default=True, + description="show stats report at the end of a run", + ) + retry_all: bool = Field( + default=False, + description="retry all downloads", + ) + retry_failed: bool = Field( + default=False, + description="retry failed downloads", + ) + retry_maintenance: bool = Field( + default=False, + description="retry download of maintenance files (bunkr). Requires files to be hashed", + ) + show_supported_sites: bool = Field( + default=False, + description="shows a list of supported sites and exits", + ) + ui: UIOptions = Field( + default=UIOptions.FULLSCREEN, + description="DISABLED, ACTIVITY, SIMPLE or FULLSCREEN", + ) + + @property + def retry_any(self) -> bool: + return any((self.retry_all, self.retry_failed, self.retry_maintenance)) + + @property + def fullscreen_ui(self) -> bool: + return self.ui == UIOptions.FULLSCREEN + + @computed_field + def __computed__(self) -> dict[str, bool]: + return {"retry_any": self.retry_any, "fullscreen_ui": self.fullscreen_ui} + + @model_validator(mode="after") + def mutually_exclusive(self) -> Self: + group1 = [self.links, self.retry_all, self.retry_failed, self.retry_maintenance] + msg1 = "`--links`, '--retry-all', '--retry-maintenace' and '--retry-failed' are mutually exclusive" + _check_mutually_exclusive(group1, msg1) + return self + + @field_validator("ui", mode="before") + @classmethod + def lower(cls, value: str) -> str: + return value.lower() + + +def _check_mutually_exclusive(group: Iterable[Any], msg: str) -> None: + if sum(1 for value in group if value) >= 2: + raise ValueError(msg) + + +class ParsedArgs(BaseModel): + cli_only_args: CLIargs = CLIargs() + config_settings: ConfigSettings = ConfigSettings() + global_settings: GlobalSettings = GlobalSettings() + + def model_post_init(self, *_) -> None: + if self.cli_only_args.retry_all or self.cli_only_args.retry_maintenance: + self.config_settings.runtime_options.ignore_history = True + + if ( + not self.cli_only_args.fullscreen_ui + or self.cli_only_args.retry_any + or self.cli_only_args.config_file + or self.config_settings.sorting.sort_downloads + ): + self.cli_only_args.download = True diff --git a/cyberdrop_dl/config/config_model.py b/cyberdrop_dl/config/config_model.py index c8e58fdac..0b33aa67e 100755 --- a/cyberdrop_dl/config/config_model.py +++ b/cyberdrop_dl/config/config_model.py @@ -66,7 +66,7 @@ def valid_format(cls, value: str) -> str: class Files(AliasModel): download_folder: Path = Field(default=DEFAULT_DOWNLOAD_STORAGE, validation_alias="d") dump_json: bool = Field(default=False, validation_alias="j") - input_file: Path = Field(default=DEFAULT_APP_STORAGE / "Configs{config}/URLs.txt", validation_alias="i") + input_file: Path = Field(default=DEFAULT_APP_STORAGE / "Configs/{config}/URLs.txt", validation_alias="i") save_pages_html: bool = False diff --git a/scripts/tools/update_docs.py b/scripts/tools/update_docs.py index 6de56331e..e235aa4fd 100644 --- a/scripts/tools/update_docs.py +++ b/scripts/tools/update_docs.py @@ -1,7 +1,7 @@ from pathlib import Path from cyberdrop_dl import __version__ -from cyberdrop_dl.cli import CDL_EPILOG, CustomHelpFormatter, make_parser +from cyberdrop_dl.cli import CustomHelpFormatter, make_parser from cyberdrop_dl.utils.markdown import get_crawlers_info_as_markdown_table REPO_ROOT = Path(__file__).parents[2] @@ -10,7 +10,7 @@ def update_cli_overview() -> None: - parser, _ = make_parser() + parser = make_parser().parser def get_wide_formatter(_=None) -> CustomHelpFormatter: return CustomHelpFormatter(parser.prog, width=300) @@ -18,7 +18,7 @@ def get_wide_formatter(_=None) -> CustomHelpFormatter: parser._get_formatter = get_wide_formatter help_text = parser.format_help() shell = "```shell" - cli_overview, *_ = help_text.partition(CDL_EPILOG) + cli_overview, *_ = help_text.partition("") current_content = CLI_ARGUMENTS_MD.read_text(encoding="utf8") new_content, *_ = current_content.partition(shell) new_content += f"{shell}\n{cli_overview}```\n" diff --git a/tests/test_cli.py b/tests/test_cli.py index a6b190981..6fc36026f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -26,7 +26,7 @@ def test_command_by_console_output(tmp_cwd: Path, capsys: pytest.CaptureFixture[ def test_startup_logger_should_not_be_created_on_a_successful_run(tmp_cwd: Path) -> None: run("--download") - startup_file = Path.cwd() / "startup.log" + startup_file = tmp_cwd / "startup.log" assert not startup_file.exists() @@ -41,7 +41,7 @@ def test_startup_logger_should_not_be_created_on_invalid_cookies(tmp_cwd: Path) logs = director.manager.path_manager.main_log.read_text(encoding="utf8") assert "does not look like a Netscape format cookies file" in logs - startup_file = Path.cwd() / "startup.log" + startup_file = tmp_cwd / "startup.log" assert not startup_file.exists() @@ -56,7 +56,7 @@ def test_startup_logger_is_created_on_yaml_error(tmp_cwd: Path) -> None: except SystemExit: pass - startup_file = Path.cwd() / "startup.log" + startup_file = tmp_cwd / "startup.log" assert startup_file.exists() logs = startup_file.read_text(encoding="utf8") @@ -80,7 +80,7 @@ def test_startup_logger_when_manager_startup_fails( run("--download") except SystemExit: pass - startup_file = Path.cwd() / "startup.log" + startup_file = tmp_cwd / "startup.log" assert startup_file.exists() == exists @@ -89,7 +89,7 @@ def test_startup_logger_should_not_be_created_when_using_invalid_cli_args(tmp_cw run("--invalid-command") except SystemExit: pass - startup_file = Path.cwd() / "startup.log" + startup_file = tmp_cwd / "startup.log" assert not startup_file.exists()