diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index e1db652..2ff4992 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -1,2 +1,5 @@ 1e76c26c673b709f3296d567f60954a373169fb5 +235927ce73125df31ec3a0049b067afa1f0a135b +25eefd52d023f97870d8b4a27988f8fc91c3ed73 +7d8c46cce63ff1b93836b72cdad74ec796b09ced diff --git a/.gitignore b/.gitignore index d8a1bb0..3bdf40f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ __pycache__ .coverage +.vscode diff --git a/relink.py b/relink.py index b4f59d4..e9f2cc5 100644 --- a/relink.py +++ b/relink.py @@ -5,17 +5,24 @@ """ import os -import sys import pwd import argparse import logging import time from pathlib import Path -from shared import DEFAULT_INPUTDATA_ROOT, DEFAULT_STAGING_ROOT - -# Set up logger -logger = logging.getLogger(__name__) +from shared import ( + DEFAULT_INPUTDATA_ROOT, + DEFAULT_STAGING_ROOT, + get_log_level, + add_parser_verbosity_group, + add_inputdata_root, + validate_paths, + validate_directory, + configure_logging, + logger, + INDENT, +) # Define a custom log level that always prints ALWAYS = logging.CRITICAL * 2 @@ -157,9 +164,7 @@ def find_owned_files_scandir(item, user_uid, inputdata_root=DEFAULT_INPUTDATA_RO ) # Things other than directories are handled separately - elif ( - entry_path := handle_non_dir(entry, user_uid) - ) is not None: + elif (entry_path := handle_non_dir(entry, user_uid)) is not None: yield entry_path except (OSError, PermissionError) as e: @@ -175,7 +180,11 @@ def find_owned_files_scandir(item, user_uid, inputdata_root=DEFAULT_INPUTDATA_RO def replace_files_with_symlinks( - item_to_process, target_dir, username, inputdata_root=DEFAULT_INPUTDATA_ROOT, dry_run=False + item_to_process, + target_dir, + username, + inputdata_root=DEFAULT_INPUTDATA_ROOT, + dry_run=False, ): """ Finds files owned by a specific user in a source directory tree, @@ -210,13 +219,15 @@ def replace_files_with_symlinks( ) # Use efficient scandir-based search - for file_path in 
find_owned_files_scandir(item_to_process, user_uid, inputdata_root): - replace_one_file_with_symlink(inputdata_root, target_dir, file_path, dry_run=dry_run) + for file_path in find_owned_files_scandir( + item_to_process, user_uid, inputdata_root + ): + replace_one_file_with_symlink( + inputdata_root, target_dir, file_path, dry_run=dry_run + ) -def replace_one_file_with_symlink( - inputdata_root, target_dir, file_path, dry_run=False -): +def replace_one_file_with_symlink(inputdata_root, target_dir, file_path, dry_run=False): """ Given a file, replaces it with a symbolic link to the same relative path in a target directory tree. @@ -227,7 +238,7 @@ def replace_one_file_with_symlink( file_path (str): The path of the file to be replaced. dry_run (bool): If True, only show what would be done without making changes. """ - logger.info("Found owned file: %s", file_path) + logger.info("'%s':", file_path) # Determine the relative path and the new link's destination relative_path = os.path.relpath(file_path, inputdata_root) @@ -236,9 +247,9 @@ def replace_one_file_with_symlink( # Check if the target file actually exists if not os.path.exists(link_target): logger.warning( - "Warning: Corresponding file '%s' not found for '%s'. Skipping.", + "%sWarning: Corresponding file '%s' not found. Skipping.", + INDENT, link_target, - file_path, ) return @@ -247,7 +258,8 @@ def replace_one_file_with_symlink( if dry_run: logger.info( - "[DRY RUN] Would create symbolic link: %s -> %s", + "%s[DRY RUN] Would create symbolic link: %s -> %s", + INDENT, link_name, link_target, ) @@ -256,9 +268,9 @@ def replace_one_file_with_symlink( # Remove the original file try: os.rename(link_name, link_name + ".tmp") - logger.info("Deleted original file: %s", link_name) + logger.info("%sDeleted original file: %s", INDENT, link_name) except OSError as e: - logger.error("Error deleting file %s: %s. Skipping.", link_name, e) + logger.error("%sError deleting file %s: %s. 
Skipping.", INDENT, link_name, e) return # Create the symbolic link, handling necessary parent directories @@ -267,52 +279,12 @@ def replace_one_file_with_symlink( os.makedirs(os.path.dirname(link_name), exist_ok=True) os.symlink(link_target, link_name) os.remove(link_name + ".tmp") - logger.info("Created symbolic link: %s -> %s", link_name, link_target) + logger.info("%sCreated symbolic link: %s -> %s", INDENT, link_name, link_target) except OSError as e: os.rename(link_name + ".tmp", link_name) - logger.error("Error creating symlink for %s: %s. Skipping.", link_name, e) - - -def validate_paths(path, check_is_dir=False): - """ - Validate that one or more paths exist. - - Args: - path (str or list): The path to validate, or a list of such paths. - - Returns: - str or list: The absolute path(s) if valid. - - Raises: - argparse.ArgumentTypeError: If a path doesn't exist. - """ - if isinstance(path, list): - result = [] - for item in path: - result.append(validate_paths(item, check_is_dir=check_is_dir)) - return result - - if not os.path.exists(path): - raise argparse.ArgumentTypeError(f"'{path}' does not exist") - if check_is_dir and not os.path.isdir(path): - raise argparse.ArgumentTypeError(f"'{path}' is not a directory") - return os.path.abspath(path) - - -def validate_directory(path): - """ - Validate that one or more directories exist. - - Args: - path (str or list): The directory to validate, or a list of such directories. - - Returns: - str or list: The absolute path(s) if valid. - - Raises: - argparse.ArgumentTypeError: If a path doesn't exist. - """ - return validate_paths(path, check_is_dir=True) + logger.error( + "%sError creating symlink for %s: %s. Skipping.", INDENT, link_name, e + ) def parse_arguments(): @@ -347,27 +319,11 @@ def parse_arguments(): ), ) - # The root of the directory tree containing CESM input data. 
- # ONLY INTENDED FOR USE IN TESTING - parser.add_argument( - "--inputdata-root", - "-inputdata", # to match rimport - type=validate_directory, - default=DEFAULT_INPUTDATA_ROOT, - help=argparse.SUPPRESS, - ) + # Add inputdata_root option flags + add_inputdata_root(parser) - # Verbosity options (mutually exclusive) - verbosity_group = parser.add_mutually_exclusive_group() - verbosity_group.add_argument( - "-v", "--verbose", action="store_true", help="Enable verbose output" - ) - verbosity_group.add_argument( - "-q", - "--quiet", - action="store_true", - help="Quiet mode (show only warnings and errors)", - ) + # Add verbosity options + add_parser_verbosity_group(parser) parser.add_argument( "--dry-run", @@ -397,15 +353,12 @@ def process_args(args): args (argparse.Namespace): Parsed command-line arguments. """ # Configure logging based on verbosity flags - if args.quiet: - args.log_level = logging.WARNING - elif args.verbose: - args.log_level = logging.DEBUG - else: - args.log_level = logging.INFO + args.log_level = get_log_level(quiet=args.quiet, verbose=args.verbose) # Ensure that items_to_process is a list - if hasattr(args, "items_to_process") and not isinstance(args.items_to_process, list): + if hasattr(args, "items_to_process") and not isinstance( + args.items_to_process, list + ): args.items_to_process = [args.items_to_process] # Check that everything is an absolute path (should have been converted, if needed, during @@ -438,7 +391,7 @@ def main(): args = parse_arguments() - logging.basicConfig(level=args.log_level, format="%(message)s", stream=sys.stdout) + configure_logging(args.log_level) my_username = os.environ["USER"] diff --git a/rimport b/rimport index 9c09239..3539cd3 100755 --- a/rimport +++ b/rimport @@ -9,6 +9,7 @@ Do `rimport --help` for more information. 
from __future__ import annotations import argparse +import logging import os import pwd import shutil @@ -19,13 +20,16 @@ from urllib.request import Request, urlopen from urllib.error import HTTPError import shared +INDENT = shared.INDENT DEFAULT_INPUTDATA_ROOT = Path(shared.DEFAULT_INPUTDATA_ROOT) DEFAULT_STAGING_ROOT = Path(shared.DEFAULT_STAGING_ROOT) STAGE_OWNER = "cesmdata" -INDENT = " " INPUTDATA_URL = "https://osdf-data.gdex.ucar.edu/ncar/gdex/d651077/cesmdata/inputdata" +# Configure logging +logger = shared.logger + def build_parser() -> argparse.ArgumentParser: """Build and configure the argument parser for rimport. @@ -41,42 +45,41 @@ def build_parser() -> argparse.ArgumentParser: argparse.ArgumentParser: Configured parser ready to parse command-line arguments. """ parser = argparse.ArgumentParser( - description="Copy files from CESM inputdata directory to a publishing directory.", + description=( + f"Copy files from CESM inputdata directory ({DEFAULT_INPUTDATA_ROOT}) to a publishing" + " directory." + ), add_help=False, # Disable automatic help to add custom -help flag ) - # Mutually exclusive: -file or -list (one required) - group = parser.add_mutually_exclusive_group(required=True) - group.add_argument( + parser.add_argument( "--file", "-file", dest="file", metavar="filename", - help="Provide a single filename relative to the top inputdata directory", + help="Provide a file to import. Must be in the CESM inputdata directory.", ) - group.add_argument( + + parser.add_argument( "--list", "-list", dest="filelist", metavar="filelist", help=( "Provide a file that contains a list of filenames to import. All filenames in the list" - "are relative to the top inputdata area." + " must be in the CESM inputdata directory." ), ) parser.add_argument( - "--inputdata", - "-inputdata", - dest="inputdata", - metavar="inputdata_dir", - default=DEFAULT_INPUTDATA_ROOT, - help=( - "Change the default local top level inputdata directory." 
- f" Default: '{DEFAULT_INPUTDATA_ROOT}'" - ), + "items_to_process", + nargs="*", + help="One or more files to process. (Optional; can use --file instead to process just one.)" ) + # Add inputdata_root option flags + shared.add_inputdata_root(parser) + parser.add_argument( "--check", "-check", @@ -85,14 +88,11 @@ def build_parser() -> argparse.ArgumentParser: help="Check whether file(s) is/are already published.", ) - # Provide -help to mirror legacy behavior (in addition to -h and --help) - parser.add_argument( - "-h", - "--help", - "-help", - action="help", - help="Show this help message and exit", - ) + # Add verbosity options + shared.add_parser_verbosity_group(parser) + + # Add help text + shared.add_help(parser) return parser @@ -177,7 +177,7 @@ def stage_data( f"Source is a symlink, but target ({src.resolve()}) is outside staging directory " f"({staging_root})" ) - print(f"{INDENT}File is already published and linked.") + logger.info("%sFile is already published and linked.", INDENT) print_can_file_be_downloaded( can_file_be_downloaded(src.resolve(), staging_root) ) @@ -200,21 +200,19 @@ def stage_data( dst = staging_root / rel if dst.exists(): - print(f"{INDENT}File is already published but NOT linked; do") - print(f"{2*INDENT}relink.py {rel}") - print(f"{INDENT}to resolve.") - print_can_file_be_downloaded( - can_file_be_downloaded(rel, staging_root) - ) + logger.info("%sFile is already published but NOT linked; do", INDENT) + logger.info("%srelink.py %s", 2 * INDENT, rel) + logger.info("%sto resolve.", INDENT) + print_can_file_be_downloaded(can_file_be_downloaded(rel, staging_root)) return if check: - print(f"{INDENT}File is not already published") + logger.info("%sFile is not already published", INDENT) return dst.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(src, dst) - print(f"{INDENT}[rimport] staged {src} -> {dst}") + logger.info("%s[rimport] staged %s -> %s", INDENT, src, dst) def ensure_running_as(target_user: str, argv: list[str]) -> None: 
@@ -237,20 +235,18 @@ def ensure_running_as(target_user: str, argv: list[str]) -> None: try: target_uid = pwd.getpwnam(target_user).pw_uid except KeyError as exc: - print( - f"rimport: target user '{target_user}' not found on this system", - file=sys.stderr, - ) + logger.error("rimport: target user '%s' not found on this system", target_user) raise SystemExit(2) from exc if os.geteuid() != target_uid: try: assert sys.stdin.isatty() except AssertionError as exc: - print( - f"rimport: need interactive TTY to authenticate as '{target_user}' (2FA).\n" - f" Try: sudo -u {target_user} rimport …", - file=sys.stderr, + logger.error( + "rimport: need interactive TTY to authenticate as '%s' (2FA).\n" + " Try: sudo -u %s rimport …", + target_user, + target_user, ) raise SystemExit(2) from exc # Re-exec under target user; this invokes sudo’s normal password/2FA flow. @@ -310,9 +306,49 @@ def print_can_file_be_downloaded(file_can_be_downloaded: bool): file_can_be_downloaded: Boolean indicating if the file can be downloaded. """ if file_can_be_downloaded: - print(f"{INDENT}File is available for download.") + logger.info("%sFile is available for download.", INDENT) else: - print(f"{INDENT}File is not (yet) available for download.") + logger.info("%sFile is not (yet) available for download.", INDENT) + + +def get_files_to_process(file: str, filelist: str, items_to_process: list): + """Get list of files to process. + + Uses --file and/or --filelist arguments, as well as positional items_to_process if given. + + Args: + file (str): Single file to process. + filelist (str): File containing list of files to process. + items_to_process (list): List of files to process. 
+ + Returns: + list: List of files to process + int: Result code + """ + if file is not None: + files_to_process = [file] + else: + files_to_process = [] + + if filelist is not None: + list_path = Path(filelist).expanduser().resolve() + if not list_path.exists(): + logger.error("rimport: list file not found: %s", list_path) + return None, 2 + files_in_list = read_filelist(list_path) + if not files_in_list: + logger.error("rimport: no filenames found in list: %s", list_path) + return None, 2 + files_to_process.extend(files_in_list) + + if items_to_process: + files_to_process.extend(items_to_process) + + if not files_to_process: + logger.error("rimport: At least one of --file or --filelist is required") + return None, 2 + + return files_to_process, 0 def main(argv: List[str] | None = None) -> int: @@ -340,43 +376,39 @@ def main(argv: List[str] | None = None) -> int: parser = build_parser() args = parser.parse_args(argv) + # Configure logging based on verbosity flags + log_level = shared.get_log_level(quiet=args.quiet, verbose=args.verbose) + shared.configure_logging(log_level) + # Ensure we are running as the STAGE_OWNER account before touching the tree # Set env var RIMPORT_SKIP_USER_CHECK=1 if you prefer to run `sudox -u STAGE_OWNER rimport …` # explicitly (or for testing). 
if not args.check and os.getenv("RIMPORT_SKIP_USER_CHECK") != "1": ensure_running_as(STAGE_OWNER, sys.argv) - root = Path(args.inputdata).expanduser().resolve() + root = Path(args.inputdata_root).expanduser().resolve() if not root.exists(): - print(f"rimport: inputdata directory does not exist: {root}", file=sys.stderr) + logger.error("rimport: inputdata directory does not exist: %s", root) return 2 # Determine the list of relative filenames to handle - if args.file is not None: - relnames = [args.file] - else: - list_path = Path(args.filelist).expanduser().resolve() - if not list_path.exists(): - print(f"rimport: list file not found: {list_path}", file=sys.stderr) - return 2 - relnames = read_filelist(list_path) - if not relnames: - print(f"rimport: no filenames found in list: {list_path}", file=sys.stderr) - return 2 + files_to_process, status = get_files_to_process(args.file, args.filelist, args.items_to_process) + if status: + return status # Resolve to full paths (keep accepting absolute names too) - paths = normalize_paths(root, relnames) + paths = normalize_paths(root, files_to_process) staging_root = get_staging_root() # Execute the new action per file errors = 0 for p in paths: - print(f"'{p}':") + logger.info("'%s':", p) try: stage_data(p, root, staging_root, args.check) except Exception as e: # pylint: disable=broad-exception-caught # General Exception keeps CLI robust for batch runs errors += 1 - print(f"{INDENT}rimport: error processing {p}: {e}", file=sys.stderr) + logger.error("%srimport: error processing %s: %s", INDENT, p, e) return 0 if errors == 0 else 1 diff --git a/shared.py b/shared.py index fad4df0..2750730 100644 --- a/shared.py +++ b/shared.py @@ -2,7 +2,174 @@ Things shared between rimport and relink """ +import logging +import argparse +import os +import sys + DEFAULT_INPUTDATA_ROOT = "/glade/campaign/cesm/cesmdata/cseg/inputdata/" DEFAULT_STAGING_ROOT = ( "/glade/campaign/collections/gdex/data/d651077/cesmdata/inputdata/" ) +INDENT = " " 
+ +logger = logging.getLogger("rimport_relink") + + +def get_log_level(quiet: bool = False, verbose: bool = False) -> int: + """Determine logging level based on quiet and verbose flags. + + Args: + quiet: If True, show only warnings and errors (WARNING level). + verbose: If True, show debug messages (DEBUG level). + + Returns: + int: Logging level (DEBUG, INFO, or WARNING). + + Note: + If both quiet and verbose are True, quiet takes precedence. + """ + if quiet: + return logging.WARNING + if verbose: + return logging.DEBUG + return logging.INFO + + +def configure_logging(log_level: int, logger_in: logging.Logger = logger) -> None: + """Configure logging to send INFO/WARNING to stdout and ERROR/CRITICAL to stderr. + + Sets up two handlers: + - INFO handler: Sends INFO, WARNING, and DEBUG level messages to stdout + - ERROR handler: Sends ERROR and CRITICAL level messages to stderr + + Both handlers use simple message-only formatting without timestamps or level names. + + Args: + log_level: Minimum logging level (DEBUG, INFO, or WARNING). + logger_in: Logger to operate on. Should only be used in testing. + """ + logger_in.setLevel(log_level) + + # Handler for INFO, WARNING, and DEBUG level messages -> stdout + info_handler = logging.StreamHandler(sys.stdout) + info_handler.setLevel(logging.DEBUG) # Accept all levels, filter will handle it + info_handler.addFilter(lambda record: record.levelno < logging.ERROR) + info_handler.setFormatter(logging.Formatter("%(message)s")) + + # Handler for ERROR and CRITICAL level messages -> stderr + error_handler = logging.StreamHandler(sys.stderr) + error_handler.setLevel(logging.ERROR) + error_handler.setFormatter(logging.Formatter("%(message)s")) + + # Clear any existing handlers and add our custom ones + logger_in.handlers.clear() + logger_in.addHandler(info_handler) + logger_in.addHandler(error_handler) + + +def add_inputdata_root(parser: argparse.ArgumentParser): + """Add inputdata_root option to an argument parser. 
+ + The root of the directory tree containing CESM input data. Only intended for use in testing, so + help is suppressed. + + Args: + parser: ArgumentParser instance to add the inputdata_root arg to. + """ + parser.add_argument( + "--inputdata-root", + "-inputdata-root", + "--inputdata", + "-inputdata", + "-i", + type=validate_directory, + default=DEFAULT_INPUTDATA_ROOT, + help=argparse.SUPPRESS, + ) + + +def add_help(parser: argparse.ArgumentParser): + """Add help option to an argument parser. + + Provides -help to mirror legacy rimport behavior (in addition to -h and --help). + + Args: + parser: ArgumentParser instance to add the help arg to. + """ + parser.add_argument( + "-h", + "--help", + "-help", + action="help", + help="Show this help message and exit", + ) + + +def add_parser_verbosity_group(parser: argparse.ArgumentParser): + """Add mutually exclusive verbosity options to an argument parser. + + Adds -v/--verbose and -q/--quiet flags as a mutually exclusive group. + + Args: + parser: ArgumentParser instance to add the verbosity group to. + + Returns: + The mutually exclusive argument group that was created. + """ + verbosity_group = parser.add_mutually_exclusive_group() + verbosity_group.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable verbose output (DEBUG level)", + ) + verbosity_group.add_argument( + "-q", + "--quiet", + action="store_true", + help="Quiet mode (show only warnings and errors)", + ) + return verbosity_group + + +def validate_paths(path, check_is_dir=False): + """ + Validate that one or more paths exist. + + Args: + path (str or list): The path to validate, or a list of such paths. + + Returns: + str or list: The absolute path(s) if valid. + + Raises: + argparse.ArgumentTypeError: If a path doesn't exist. 
+ """ + if isinstance(path, list): + result = [] + for item in path: + result.append(validate_paths(item, check_is_dir=check_is_dir)) + return result + + if not os.path.exists(path): + raise argparse.ArgumentTypeError(f"'{path}' does not exist") + if check_is_dir and not os.path.isdir(path): + raise argparse.ArgumentTypeError(f"'{path}' is not a directory") + return os.path.abspath(path) + + +def validate_directory(path): + """ + Validate that one or more directories exist. + + Args: + path (str or list): The directory to validate, or a list of such directories. + + Returns: + str or list: The absolute path(s) if valid. + + Raises: + argparse.ArgumentTypeError: If a path doesn't exist. + """ + return validate_paths(path, check_is_dir=True) diff --git a/tests/conftest.py b/tests/conftest.py index e1eff2d..0a7e232 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,33 @@ """ -Pytest configuration and shared fixtures for relink tests. +Pytest configuration and shared fixtures for all tests. 
""" import os +import tempfile +import shutil import pytest +from unittest.mock import patch @pytest.fixture(scope="session") def workspace_root(): """Return the root directory of the workspace.""" return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +@pytest.fixture(scope="function", name="temp_dirs") +def fixture_temp_dirs(): + """Create temporary source and target directories for testing.""" + source_dir = tempfile.mkdtemp(prefix="test_source_") + target_dir = tempfile.mkdtemp(prefix="test_target_") + + with patch("relink.DEFAULT_INPUTDATA_ROOT", source_dir): + with patch("relink.DEFAULT_STAGING_ROOT", target_dir): + with patch("shared.DEFAULT_INPUTDATA_ROOT", source_dir): + with patch("shared.DEFAULT_STAGING_ROOT", target_dir): + yield source_dir, target_dir + + # Cleanup + shutil.rmtree(source_dir, ignore_errors=True) + shutil.rmtree(target_dir, ignore_errors=True) diff --git a/tests/relink/conftest.py b/tests/relink/conftest.py index 73a4e19..f7fa29b 100644 --- a/tests/relink/conftest.py +++ b/tests/relink/conftest.py @@ -3,26 +3,8 @@ """ import os -import tempfile -import shutil import pytest -from unittest.mock import patch - - -@pytest.fixture(scope="function", name="temp_dirs") -def fixture_temp_dirs(): - """Create temporary source and target directories for testing.""" - source_dir = tempfile.mkdtemp(prefix="test_source_") - target_dir = tempfile.mkdtemp(prefix="test_target_") - - with patch("relink.DEFAULT_INPUTDATA_ROOT", source_dir): - with patch("relink.DEFAULT_STAGING_ROOT", target_dir): - yield source_dir, target_dir - - # Cleanup - shutil.rmtree(source_dir, ignore_errors=True) - shutil.rmtree(target_dir, ignore_errors=True) @pytest.fixture(name="current_user") diff --git a/tests/relink/test_args.py b/tests/relink/test_args.py index 38f9683..e44e739 100644 --- a/tests/relink/test_args.py +++ b/tests/relink/test_args.py @@ -193,145 +193,16 @@ def test_multiple_source_roots_with_target(self, temp_dirs): assert str(source2.resolve()) 
in args.items_to_process assert args.target_root == str(target.resolve()) - -class TestValidateDirectory: - """Test suite for validate_directory function.""" - - def test_valid_directory(self, tmp_path): - """Test that valid directory is accepted and returns absolute path.""" - test_dir = tmp_path / "valid_dir" - test_dir.mkdir() - - result = relink.validate_directory(str(test_dir)) - assert result == str(test_dir.resolve()) - - def test_nonexistent_directory(self): - """Test that nonexistent directory raises ArgumentTypeError.""" - nonexistent = os.path.join(os.sep, "nonexistent", "directory", "12345") - - with pytest.raises(argparse.ArgumentTypeError) as exc_info: - relink.validate_directory(nonexistent) - - assert "does not exist" in str(exc_info.value) - assert nonexistent in str(exc_info.value) - - def test_relative_path_converted_to_absolute(self, tmp_path): - """Test that relative paths are converted to absolute.""" - test_dir = tmp_path / "relative_test" - test_dir.mkdir() - - # Change to parent directory and use relative path - cwd = os.getcwd() - try: - os.chdir(str(tmp_path)) - result = relink.validate_directory("relative_test") - assert os.path.isabs(result) - assert result == str(test_dir.resolve()) - finally: - os.chdir(cwd) - - def test_symlink_to_directory(self, tmp_path): - """Test that symlink to a directory is accepted.""" - real_dir = tmp_path / "real_dir" - real_dir.mkdir() - - link_dir = tmp_path / "link_dir" - link_dir.symlink_to(real_dir) - - result = relink.validate_paths(str(link_dir)) - # validate_directory returns absolute path of the symlink itself - assert result == str(link_dir.absolute()) - # Verify it's still a symlink - assert os.path.islink(result) - - def test_list_with_invalid_directory(self, tmp_path): - """Test that a list with one invalid directory raises error.""" - dir1 = tmp_path / "dir1" - dir1.mkdir() - nonexistent = tmp_path / "nonexistent" - - with pytest.raises(argparse.ArgumentTypeError) as exc_info: - 
relink.validate_paths([str(dir1), str(nonexistent)]) - - assert "does not exist" in str(exc_info.value) - - -class TestValidatePaths: - """Test suite for validate_paths function.""" - - def test_valid_directory(self, tmp_path): - """Test that valid directory is accepted and returns absolute path.""" - test_dir = tmp_path / "valid_dir" - test_dir.mkdir() - - result = relink.validate_paths(str(test_dir)) - assert result == str(test_dir.resolve()) - - def test_nonexistent_directory(self): - """Test that nonexistent directory raises ArgumentTypeError.""" - nonexistent = os.path.join(os.sep, "nonexistent", "directory", "12345") - - with pytest.raises(argparse.ArgumentTypeError) as exc_info: - relink.validate_paths(nonexistent) - - assert "does not exist" in str(exc_info.value) - assert nonexistent in str(exc_info.value) - - def test_file_instead_of_directory(self, tmp_path): - """Test that a file path doesn't raise ArgumentTypeError (or any error).""" - test_file = tmp_path / "test_file.txt" - test_file.write_text("content") - - relink.validate_paths(str(test_file)) - - def test_relative_path_converted_to_absolute(self, tmp_path): - """Test that relative paths are converted to absolute.""" - test_dir = tmp_path / "relative_test" - test_dir.mkdir() - - # Change to parent directory and use relative path - cwd = os.getcwd() - try: - os.chdir(str(tmp_path)) - result = relink.validate_paths("relative_test") - assert os.path.isabs(result) - assert result == str(test_dir.resolve()) - finally: - os.chdir(cwd) - - def test_symlink_to_directory(self, tmp_path): - """Test that symlink to a directory is accepted.""" - real_dir = tmp_path / "real_dir" - real_dir.mkdir() - - link_dir = tmp_path / "link_dir" - link_dir.symlink_to(real_dir) - - result = relink.validate_paths(str(link_dir)) - # validate_directory returns absolute path of the symlink itself - assert result == str(link_dir.absolute()) - # Verify it's still a symlink - assert os.path.islink(result) - - def 
test_list_with_invalid_directory(self, tmp_path): - """Test that a list with one invalid directory raises error.""" - dir1 = tmp_path / "dir1" - dir1.mkdir() - nonexistent = tmp_path / "nonexistent" - - with pytest.raises(argparse.ArgumentTypeError) as exc_info: - relink.validate_paths([str(dir1), str(nonexistent)]) - - assert "does not exist" in str(exc_info.value) - - def test_list_with_file_instead_of_directory(self, tmp_path): - """Test that a list containing a file doesn't raise error.""" - dir1 = tmp_path / "dir1" - dir1.mkdir() - file1 = tmp_path / "file.txt" - file1.write_text("content") - - relink.validate_paths([str(dir1), str(file1)]) + @pytest.mark.parametrize( + "inputdata_flag", + ["-inputdata", "-i", "--inputdata", "--inputdata-root", "-inputdata-root"], + ) + def test_inputdata_arguments_accepted(self, temp_dirs, inputdata_flag): + """Test that all inputdata argument flags are accepted.""" + inputdata_root, _ = temp_dirs + with patch("sys.argv", ["relink.py", inputdata_flag, inputdata_root]): + args = relink.parse_arguments() + assert args.inputdata_root == inputdata_root class TestProcessArgs: diff --git a/tests/relink/test_cmdline.py b/tests/relink/test_cmdline.py index 44cd27c..46a03c8 100644 --- a/tests/relink/test_cmdline.py +++ b/tests/relink/test_cmdline.py @@ -9,6 +9,8 @@ import pytest +from shared import INDENT + @pytest.fixture(name="mock_dirs") def fixture_mock_dirs(tmp_path): @@ -57,7 +59,7 @@ def test_command_line_execution_dry_run(mock_dirs): # Verify dry-run messages in output assert "DRY RUN MODE" in result.stdout - assert "[DRY RUN] Would create symbolic link:" in result.stdout + assert f"{INDENT}[DRY RUN] Would create symbolic link:" in result.stdout # Verify no actual changes were made assert source_file.is_file() @@ -96,7 +98,7 @@ def test_command_line_execution_given_dir(mock_dirs): assert os.readlink(str(source_file)) == str(target_file) # Verify success messages in output - assert "Created symbolic link:" in result.stdout + 
assert f"{INDENT}Created symbolic link:" in result.stdout def test_command_line_execution_given_file(mock_dirs): @@ -131,7 +133,7 @@ def test_command_line_execution_given_file(mock_dirs): assert os.readlink(str(source_file)) == str(target_file) # Verify success messages in output - assert "Created symbolic link:" in result.stdout + assert f"{INDENT}Created symbolic link:" in result.stdout def test_command_line_multiple_source_dirs(temp_dirs): diff --git a/tests/relink/test_dryrun.py b/tests/relink/test_dryrun.py index ee638c4..5987ee4 100644 --- a/tests/relink/test_dryrun.py +++ b/tests/relink/test_dryrun.py @@ -16,6 +16,8 @@ # pylint: disable=wrong-import-position import relink # noqa: E402 +from shared import INDENT + @pytest.fixture(name="dry_run_setup") def fixture_dry_run_setup(temp_dirs): @@ -70,7 +72,7 @@ def test_dry_run_shows_message(dry_run_setup, caplog): # Check that dry-run messages were logged assert "DRY RUN MODE" in caplog.text - assert "[DRY RUN] Would create symbolic link:" in caplog.text + assert f"{INDENT}[DRY RUN] Would create symbolic link:" in caplog.text assert f"{source_file} -> {target_file}" in caplog.text @@ -85,7 +87,7 @@ def test_dry_run_no_delete_or_create_messages(dry_run_setup, caplog): ) # Verify actual operation messages are NOT logged - assert "Deleted original file:" not in caplog.text - assert "Created symbolic link:" not in caplog.text + assert f"{INDENT}Deleted original file:" not in caplog.text + assert f"{INDENT}Created symbolic link:" not in caplog.text # But the dry-run message should be there - assert "[DRY RUN] Would create symbolic link: " in caplog.text + assert f"{INDENT}[DRY RUN] Would create symbolic link: " in caplog.text diff --git a/tests/relink/test_replace_one_file_with_symlink.py b/tests/relink/test_replace_one_file_with_symlink.py index 70caaba..741f9f8 100644 --- a/tests/relink/test_replace_one_file_with_symlink.py +++ b/tests/relink/test_replace_one_file_with_symlink.py @@ -14,6 +14,8 @@ # pylint: 
disable=wrong-import-position import relink # noqa: E402 +from shared import INDENT + def test_basic_file_replacement(temp_dirs): """Test basic functionality: replace owned file with symlink.""" @@ -83,7 +85,7 @@ def test_missing_target_file(temp_dirs, caplog): assert os.path.isfile(source_file), "Original file should still exist" # Check warning message - assert "Warning: Corresponding file " in caplog.text + assert f"{INDENT}Warning: Corresponding file " in caplog.text assert " not found" in caplog.text @@ -117,7 +119,7 @@ def test_absolute_paths(temp_dirs): def test_print_found_owned_file(temp_dirs, caplog): - """Test that 'Found owned file' message is printed.""" + """Test that message with filename is printed.""" source_dir, target_dir = temp_dirs # Create a file owned by current user @@ -133,8 +135,8 @@ def test_print_found_owned_file(temp_dirs, caplog): with caplog.at_level(logging.INFO): relink.replace_one_file_with_symlink(source_dir, target_dir, source_file) - # Check that "Found owned file" message was logged - assert "Found owned file:" in caplog.text + # Check that message was logged + assert f"'{source_file}':" in caplog.text assert source_file in caplog.text @@ -156,8 +158,8 @@ def test_print_deleted_and_created_messages(temp_dirs, caplog): relink.replace_one_file_with_symlink(source_dir, target_dir, source_file) # Check messages - assert "Deleted original file:" in caplog.text - assert "Created symbolic link:" in caplog.text + assert f"{INDENT}Deleted original file:" in caplog.text + assert f"{INDENT}Created symbolic link:" in caplog.text assert f"{source_file} -> {target_file}" in caplog.text @@ -184,7 +186,7 @@ def mock_symlink(src, dst): relink.replace_one_file_with_symlink(source_dir, target_dir, source_file) # Check error message - assert "Error creating symlink" in caplog.text + assert f"{INDENT}Error creating symlink" in caplog.text assert source_file in caplog.text @@ -254,5 +256,5 @@ def mock_rename(src, dst): 
relink.replace_one_file_with_symlink(source_dir, target_dir, source_file) # Check error message - assert "Error deleting file" in caplog.text + assert f"{INDENT}Error deleting file" in caplog.text assert source_file in caplog.text diff --git a/tests/relink/test_verbosity.py b/tests/relink/test_verbosity.py index 0234b7d..b4bca5a 100644 --- a/tests/relink/test_verbosity.py +++ b/tests/relink/test_verbosity.py @@ -15,6 +15,8 @@ # pylint: disable=wrong-import-position import relink # noqa: E402 +from shared import INDENT + def test_quiet_mode_suppresses_info_messages(temp_dirs, caplog): """Test that quiet mode suppresses INFO level messages.""" @@ -45,8 +47,8 @@ def test_quiet_mode_suppresses_info_messages(temp_dirs, caplog): assert "Searching for files owned by" not in caplog.text assert "Skipping symlink:" not in caplog.text assert "Found owned file:" not in caplog.text - assert "Deleted original file:" not in caplog.text - assert "Created symbolic link:" not in caplog.text + assert f"{INDENT}Deleted original file:" not in caplog.text + assert f"{INDENT}Created symbolic link:" not in caplog.text def test_quiet_mode_shows_warnings(temp_dirs, caplog): @@ -66,7 +68,7 @@ def test_quiet_mode_shows_warnings(temp_dirs, caplog): ) # Verify WARNING message IS in the log - assert "Warning: Corresponding file" in caplog.text + assert f"{INDENT}Warning: Corresponding file" in caplog.text assert "not found" in caplog.text @@ -104,7 +106,7 @@ def mock_rename(src, dst): relink.replace_files_with_symlinks( source_dir, target_dir, username, inputdata_root=source_dir ) - assert "Error deleting file" in caplog.text + assert f"{INDENT}Error deleting file" in caplog.text # Clear the log for next test caplog.clear() @@ -126,4 +128,4 @@ def mock_symlink(src, dst): relink.replace_files_with_symlinks( source_dir, target_dir, username, inputdata_root=source_dir ) - assert "Error creating symlink" in caplog.text + assert f"{INDENT}Error creating symlink" in caplog.text diff --git 
a/tests/rimport/test_build_parser.py b/tests/rimport/test_build_parser.py index 6213dbd..b54d710 100644 --- a/tests/rimport/test_build_parser.py +++ b/tests/rimport/test_build_parser.py @@ -3,13 +3,14 @@ """ import os -import sys import argparse import importlib.util from importlib.machinery import SourceFileLoader import pytest +import shared + # Import rimport module from file without .py extension rimport_path = os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), @@ -20,10 +21,22 @@ if spec is None: raise ImportError(f"Could not create spec for rimport from {rimport_path}") rimport = importlib.util.module_from_spec(spec) -sys.modules["rimport"] = rimport +# Don't add to sys.modules to avoid conflict with other test files loader.exec_module(rimport) +@pytest.fixture(scope="function", autouse=True) +def fixture_temp_inputdata(temp_dirs): + """ + Override rimport's DEFAULT_INPUTDATA_ROOT to ensure portability. + + We can't do it in tests/conftest.py's temp_dirs fixture because of how rimport is imported here; + we'd get "no module rimport" errors. However, we can use the tempdir that got set up in that + fixture. 
+ """ + rimport.DEFAULT_INPUTDATA_ROOT = shared.DEFAULT_INPUTDATA_ROOT + + class TestBuildParser: """Test suite for build_parser() function.""" @@ -48,41 +61,54 @@ def test_list_arguments_accepted(self, list_flag): assert args.filelist == "files.txt" assert args.file is None - @pytest.mark.parametrize("inputdata_flag", ["-inputdata", "-i", "--inputdata"]) - def test_inputdata_arguments_accepted(self, inputdata_flag): + @pytest.mark.parametrize( + "inputdata_flag", + ["-inputdata", "-i", "--inputdata", "--inputdata-root", "-inputdata-root"], + ) + def test_inputdata_arguments_accepted(self, temp_dirs, inputdata_flag): """Test that all inputdata argument flags are accepted.""" + inputdata_root, _ = temp_dirs parser = rimport.build_parser() - inputdata_dir = "/some/dir" - args = parser.parse_args([inputdata_flag, inputdata_dir, "-f", "dummy_file.nc"]) - assert args.inputdata == inputdata_dir + args = parser.parse_args( + [inputdata_flag, inputdata_root, "-f", "dummy_file.nc"] + ) + assert args.inputdata_root == inputdata_root - def test_file_and_list_mutually_exclusive(self, capsys): - """Test that -file and -list cannot be used together.""" + def test_file_and_list_not_mutually_exclusive(self, capsys): + """Test that -file and -list can be used together.""" parser = rimport.build_parser() - with pytest.raises(SystemExit): - parser.parse_args(["-file", "test.txt", "-list", "files.txt"]) - - # Check that the error message explains the problem - captured = capsys.readouterr() - stderr_lines = captured.err.strip().split("\n") - assert "not allowed with argument" in stderr_lines[-1] - - def test_file_or_list_required(self, capsys): - """Test that either -file or -list is required.""" + file = "test.txt" + filelist = "files.txt" + args = parser.parse_args(["-file", file, "-list", filelist]) + assert args.file == file + assert args.filelist == filelist + assert args.items_to_process == [] + + def test_positional_items_to_process(self): + """Test that positional 
items_to_process are accepted""" parser = rimport.build_parser() - with pytest.raises(SystemExit): - parser.parse_args([]) + items_to_process = ["abc123", "def456"] + args = parser.parse_args(items_to_process) + assert args.file is None + assert args.filelist is None + assert args.items_to_process == items_to_process - # Check that the error message explains the problem - captured = capsys.readouterr() - stderr_lines = captured.err.strip().split("\n") - assert "error: one of the arguments" in stderr_lines[-1] + def test_file_and_list_ok_with_items_to_process(self): + """Test that -file and -list can be used together with items_to_process""" + parser = rimport.build_parser() + file = "test.txt" + filelist = "files.txt" + items_to_process = ["abc123", "def456"] + args = parser.parse_args(["-file", file, "-list", filelist, *items_to_process]) + assert args.file == file + assert args.filelist == filelist + assert args.items_to_process == items_to_process def test_inputdata_default(self): """Test that -inputdata has correct default value.""" parser = rimport.build_parser() args = parser.parse_args(["-file", "test.txt"]) - assert args.inputdata == rimport.DEFAULT_INPUTDATA_ROOT + assert args.inputdata_root == str(rimport.DEFAULT_INPUTDATA_ROOT) def test_check_default(self): """Test that --check has the correct default value.""" @@ -97,12 +123,14 @@ def test_check_arguments_accepted(self, check_flag): args = parser.parse_args(["-file", "test.txt", check_flag]) assert args.check is True - def test_inputdata_custom(self): + def test_inputdata_custom(self, temp_dirs): """Test that -inputdata can be customized.""" parser = rimport.build_parser() - custom_path = "/custom/path" + inputdata_root, _ = temp_dirs + custom_path = os.path.join(inputdata_root, "custom", "path") + os.makedirs(custom_path) args = parser.parse_args(["-file", "test.txt", "-inputdata", custom_path]) - assert args.inputdata == custom_path + assert args.inputdata_root == custom_path 
@pytest.mark.parametrize("help_flag", ["-help", "-h", "--help"]) def test_help_flags_show_help(self, help_flag): @@ -113,16 +141,61 @@ def test_help_flags_show_help(self, help_flag): # Help should exit with code 0 assert exc_info.value.code == 0 - def test_file_with_inputdata(self): + def test_file_with_inputdata(self, temp_dirs): """Test combining -file with -inputdata.""" parser = rimport.build_parser() - args = parser.parse_args(["-file", "data.nc", "-inputdata", "/my/data"]) + inputdata_root, _ = temp_dirs + custom_path = os.path.join(inputdata_root, "custom", "path2") + os.makedirs(custom_path) + args = parser.parse_args(["-file", "data.nc", "-inputdata", custom_path]) assert args.file == "data.nc" - assert args.inputdata == "/my/data" + assert args.inputdata_root == custom_path - def test_list_with_inputdata(self): + def test_list_with_inputdata(self, temp_dirs): """Test combining -list with -inputdata.""" parser = rimport.build_parser() - args = parser.parse_args(["-list", "files.txt", "-inputdata", "/my/data"]) + inputdata_root, _ = temp_dirs + custom_path = os.path.join(inputdata_root, "custom", "path3") + os.makedirs(custom_path) + args = parser.parse_args(["-list", "files.txt", "-inputdata", custom_path]) assert args.filelist == "files.txt" - assert args.inputdata == "/my/data" + assert args.inputdata_root == custom_path + + def test_quiet_default(self): + """Test that quiet defaults to False.""" + parser = rimport.build_parser() + args = parser.parse_args(["-file", "test.nc"]) + assert args.quiet is False + + def test_verbose_default(self): + """Test that verbose defaults to False.""" + parser = rimport.build_parser() + args = parser.parse_args(["-file", "test.nc"]) + assert args.verbose is False + + @pytest.mark.parametrize("quiet_flag", ["-q", "--quiet"]) + def test_quiet_arguments_accepted(self, quiet_flag): + """Test that all quiet argument flags are accepted.""" + parser = rimport.build_parser() + args = parser.parse_args(["-file", "test.nc", 
quiet_flag]) + assert args.quiet is True + assert args.verbose is False + + @pytest.mark.parametrize("verbose_flag", ["-v", "--verbose"]) + def test_verbose_arguments_accepted(self, verbose_flag): + """Test that all verbose argument flags are accepted.""" + parser = rimport.build_parser() + args = parser.parse_args(["-file", "test.nc", verbose_flag]) + assert args.verbose is True + assert args.quiet is False + + def test_quiet_and_verbose_mutually_exclusive(self, capsys): + """Test that -q and -v cannot be used together.""" + parser = rimport.build_parser() + with pytest.raises(SystemExit): + parser.parse_args(["-file", "test.nc", "-q", "-v"]) + + # Check that the error message explains the problem + captured = capsys.readouterr() + stderr_lines = captured.err.strip().split("\n") + assert "not allowed with argument" in stderr_lines[-1] diff --git a/tests/rimport/test_can_file_be_downloaded.py b/tests/rimport/test_can_file_be_downloaded.py index 30fcfa4..85e3d43 100644 --- a/tests/rimport/test_can_file_be_downloaded.py +++ b/tests/rimport/test_can_file_be_downloaded.py @@ -3,7 +3,6 @@ """ import os -import sys import importlib.util from importlib.machinery import SourceFileLoader from pathlib import Path @@ -22,7 +21,7 @@ if spec is None: raise ImportError(f"Could not create spec for rimport from {rimport_path}") rimport = importlib.util.module_from_spec(spec) -sys.modules["rimport"] = rimport +# Don't add to sys.modules to avoid conflict with other test files loader.exec_module(rimport) RELPATH_THAT_DOES_EXIST = os.path.join( @@ -33,7 +32,9 @@ class TestCanFileBeDownloaded: """Test suite for can_file_be_downloaded() function.""" - @pytest.mark.skipif(not os.path.exists("/glade"), reason="This test can only run on Glade") + @pytest.mark.skipif( + not os.path.exists("/glade"), reason="This test can only run on Glade" + ) def test_existing_file_exists(self): """Test that the file that should exist does. 
If not, other tests will definitely fail.""" file_abspath = Path(os.path.join(DEFAULT_STAGING_ROOT, RELPATH_THAT_DOES_EXIST)) diff --git a/tests/rimport/test_ensure_running_as.py b/tests/rimport/test_ensure_running_as.py index 07b8727..135aabf 100644 --- a/tests/rimport/test_ensure_running_as.py +++ b/tests/rimport/test_ensure_running_as.py @@ -3,7 +3,6 @@ """ import os -import sys import importlib.util from importlib.machinery import SourceFileLoader from unittest.mock import patch, MagicMock @@ -20,7 +19,7 @@ if spec is None: raise ImportError(f"Could not create spec for rimport from {rimport_path}") rimport = importlib.util.module_from_spec(spec) -sys.modules["rimport"] = rimport +# Don't add to sys.modules to avoid conflict with other test files loader.exec_module(rimport) @@ -42,7 +41,9 @@ def test_does_nothing_when_already_running_as_target_user(self): with patch("sys.stdin.isatty") as mock_isatty: with patch("os.execvp") as mock_execvp: # Should not raise or exec - rimport.ensure_running_as("testuser", ["rimport", "-file", "test.nc"]) + rimport.ensure_running_as( + "testuser", ["rimport", "-file", "test.nc"] + ) # Verify stdin.isatty and os.execvp were NOT called mock_isatty.assert_not_called() @@ -77,7 +78,7 @@ def test_execs_sudo_when_different_user_and_interactive(self): assert call_args[1][3] == "--" assert call_args[1][4:] == ["rimport", "-file", "test.nc"] - def test_error_message_for_nonexistent_user(self, capsys): + def test_error_message_for_nonexistent_user(self, caplog): """Test that appropriate error message is shown for nonexistent user.""" # Mock pwd.getpwnam to raise KeyError with patch("pwd.getpwnam", side_effect=KeyError("user not found")): @@ -85,11 +86,10 @@ def test_error_message_for_nonexistent_user(self, capsys): rimport.ensure_running_as("baduser", ["rimport", "-file", "test.nc"]) assert exc_info.value.code == 2 - captured = capsys.readouterr() - assert "baduser" in captured.err - assert "not found" in captured.err + assert "baduser" 
in caplog.text + assert "not found" in caplog.text - def test_error_message_for_non_interactive(self, capsys): + def test_error_message_for_non_interactive(self, caplog): """Test that appropriate error message is shown when not interactive.""" current_uid = os.geteuid() different_uid = current_uid + 1000 @@ -108,6 +108,5 @@ def test_error_message_for_non_interactive(self, capsys): ) assert exc_info.value.code == 2 - captured = capsys.readouterr() - assert "interactive TTY" in captured.err - assert "2FA" in captured.err + assert "interactive TTY" in caplog.text + assert "2FA" in caplog.text diff --git a/tests/rimport/test_get_files_to_process.py b/tests/rimport/test_get_files_to_process.py new file mode 100644 index 0000000..6fc13b3 --- /dev/null +++ b/tests/rimport/test_get_files_to_process.py @@ -0,0 +1,331 @@ +""" +Tests for get_files_to_process function in rimport script. +""" + +import os +import importlib.util +from importlib.machinery import SourceFileLoader + +# pylint: disable=too-many-arguments,too-many-positional-arguments + +# Import rimport module from file without .py extension +rimport_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "rimport", +) +loader = SourceFileLoader("rimport", rimport_path) +spec = importlib.util.spec_from_loader("rimport", loader) +if spec is None: + raise ImportError(f"Could not create spec for rimport from {rimport_path}") +rimport = importlib.util.module_from_spec(spec) +# Don't add to sys.modules to avoid conflict with other test files (patches here not being applied) +loader.exec_module(rimport) + + +class TestGetRelnamesToProcess: + """Test suite for get_relnames_to_process() function.""" + + def test_single_file_relpath(self, tmp_path): + """Test giving it a single file by its relative path""" + # Setup + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + staging_root = tmp_path / "staging" + staging_root.mkdir() + + filename = "test.nc" + test_file 
= inputdata_root / filename + test_file.write_text("abc123") + + # Run + files_to_process, result = rimport.get_files_to_process( + file=filename, + filelist=None, + items_to_process=None, + ) + + # Verify + assert result == 0 + assert files_to_process == [filename] + + def test_single_file_abspath(self, tmp_path): + """Test giving it a single file by its absolute path""" + # Setup + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + staging_root = tmp_path / "staging" + staging_root.mkdir() + + filename = "test.nc" + test_file = inputdata_root / filename + test_file.write_text("abc123") + + # Run + files_to_process, result = rimport.get_files_to_process( + file=test_file, + filelist=None, + items_to_process=None, + ) + + # Verify + assert result == 0 + assert files_to_process == [test_file] + + def test_filelist_relpath_with_relpaths(self, tmp_path): + """Test giving it a file list by its relative path, containing relative paths""" + # Setup + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + staging_root = tmp_path / "staging" + staging_root.mkdir() + + filenames = [] + for i in range(2): + filename = f"test{i}.txt" + filenames.append(filename) + (inputdata_root / filename).write_text("def567") + + filelist = tmp_path / "file_list.txt" + filelist.write_text("\n".join(filenames), encoding="utf8") + filelist_relpath = os.path.relpath(filelist) + + # Run + files_to_process, result = rimport.get_files_to_process( + file=None, + filelist=filelist_relpath, + items_to_process=None, + ) + + # Verify + assert result == 0 + assert files_to_process == filenames + + def test_filelist_abspath_with_relpaths(self, tmp_path): + """Test giving it a file list by its absolute path, containing relative paths""" + # Setup + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + staging_root = tmp_path / "staging" + staging_root.mkdir() + + filenames = [] + for i in range(2): + filename = f"test{i}.txt" + filenames.append(filename) + 
(inputdata_root / filename).write_text("def567") + + filelist = tmp_path / "file_list.txt" + filelist.write_text("\n".join(filenames), encoding="utf8") + + # Run + files_to_process, result = rimport.get_files_to_process( + file=None, + filelist=filelist, + items_to_process=None, + ) + + # Verify + assert result == 0 + assert files_to_process == filenames + + def test_filelist_relpath_with_abspaths(self, tmp_path): + """Test giving it a file list by its relative path, containing absolute paths""" + # Setup + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + staging_root = tmp_path / "staging" + staging_root.mkdir() + + filenames = [] + for i in range(2): + filename = inputdata_root / f"test{i}.txt" + filenames.append(str(filename)) + filename.write_text("def567") + + filelist = tmp_path / "file_list.txt" + filelist.write_text("\n".join(filenames), encoding="utf8") + filelist_relpath = os.path.relpath(filelist) + + # Run + files_to_process, result = rimport.get_files_to_process( + file=None, + filelist=filelist_relpath, + items_to_process=None, + ) + + # Verify + assert result == 0 + assert files_to_process == filenames + + def test_filelist_abspath_with_abspaths(self, tmp_path): + """Test giving it a file list by its absolute path, containing absolute paths""" + # Setup + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + staging_root = tmp_path / "staging" + staging_root.mkdir() + + filenames = [] + for i in range(2): + filename = inputdata_root / f"test{i}.txt" + filenames.append(str(filename)) + filename.write_text("def567") + + filelist = tmp_path / "file_list.txt" + filelist.write_text("\n".join(filenames), encoding="utf8") + + # Run + files_to_process, result = rimport.get_files_to_process( + file=None, + filelist=filelist, + items_to_process=None, + ) + + # Verify + assert result == 0 + assert files_to_process == filenames + + def test_filelist_not_found(self): + """Test giving it a file list that doesn't exist""" + filelist = 
"bsfearirn" + assert not os.path.exists(filelist) + files_to_process, result = rimport.get_files_to_process( + file=None, + filelist=filelist, + items_to_process=None, + ) + assert result == 2 + assert files_to_process is None + + def test_filelist_empty(self, tmp_path): + """Test giving it an empty file list""" + filelist = tmp_path / "bsfearirn" + filelist.write_text("") + files_to_process, result = rimport.get_files_to_process( + file=None, + filelist=filelist, + items_to_process=[], + ) + assert result == 2 + assert files_to_process is None + + def test_items_to_process_abspaths(self, tmp_path): + """Test giving it a list of absolute paths in items_to_process""" + # Setup + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + staging_root = tmp_path / "staging" + staging_root.mkdir() + + filenames = [] + for i in range(2): + filename = inputdata_root / f"test{i}.txt" + filenames.append(str(filename)) + filename.write_text("def567") + + # Run + files_to_process, result = rimport.get_files_to_process( + file=None, + filelist=None, + items_to_process=filenames, + ) + + # Verify + assert result == 0 + assert files_to_process == filenames + + def test_items_to_process_relpaths(self, tmp_path): + """Test giving it a list of relative paths in items_to_process""" + # Setup + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + + filenames = [] + for i in range(2): + filename = inputdata_root / f"test{i}.txt" + filenames.append(os.path.basename(filename)) + filename.write_text("def567") + + # Run + files_to_process, result = rimport.get_files_to_process( + file=None, + filelist=None, + items_to_process=filenames, + ) + + # Verify + assert result == 0 + assert files_to_process == filenames + + def test_items_to_process_mixpaths(self, tmp_path): + """Test giving it a list of absolute and relative paths in items_to_process""" + # Setup + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + + filenames = [] + for i in range(2): + 
filename = inputdata_root / f"test{i}.txt" + filenames.append(os.path.basename(filename)) + filename.write_text("def567") + for i in range(2): + filename = inputdata_root / f"test{2*i}.txt" + filenames.append(str(filename)) + filename.write_text("def567") + assert len(filenames) == 4 + + # Run + files_to_process, result = rimport.get_files_to_process( + file=None, + filelist=None, + items_to_process=filenames, + ) + + # Verify + assert result == 0 + assert files_to_process == filenames + + def test_single_file_and_list(self, tmp_path): + """Test giving it a single file by its relative path""" + # Setup + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + staging_root = tmp_path / "staging" + staging_root.mkdir() + + filename = "test.nc" + test_file = inputdata_root / filename + test_file.write_text("abc123") + + filenames = [] + for i in range(2): + f = f"test{i}.txt" + filenames.append(f) + (inputdata_root / f).write_text("def567") + + filelist = tmp_path / "file_list.txt" + filelist.write_text("\n".join(filenames), encoding="utf8") + + # Run + files_to_process, result = rimport.get_files_to_process( + file=filename, + filelist=filelist, + items_to_process=None, + ) + + # Verify + assert result == 0 + assert files_to_process == [filename] + filenames + + def test_single_or_filelist_or_list_required(self): + """Test that at least one of file, filelist, items_to_process is required""" + # Run + files_to_process, result = rimport.get_files_to_process( + file=None, + filelist=None, + items_to_process=None, + ) + + # Verify + assert result == 2 + assert files_to_process is None diff --git a/tests/rimport/test_get_staging_root.py b/tests/rimport/test_get_staging_root.py index d87732e..53f9ee3 100644 --- a/tests/rimport/test_get_staging_root.py +++ b/tests/rimport/test_get_staging_root.py @@ -3,11 +3,9 @@ """ import os -import sys import importlib.util from importlib.machinery import SourceFileLoader from pathlib import Path -from unittest.mock import 
patch import pytest @@ -21,7 +19,7 @@ if spec is None: raise ImportError(f"Could not create spec for rimport from {rimport_path}") rimport = importlib.util.module_from_spec(spec) -sys.modules["rimport"] = rimport +# Don't add to sys.modules to avoid conflict with other test files loader.exec_module(rimport) @@ -48,62 +46,62 @@ def wrapper(*args, **kwargs): class TestGetStagingRoot: """Test suite for get_staging_root() function.""" - def test_returns_default_when_env_not_set(self): + def test_returns_default_when_env_not_set(self, monkeypatch): """Test that default staging root is returned when RIMPORT_STAGING is not set.""" # Ensure RIMPORT_STAGING is not set - with patch.dict(os.environ, {}, clear=True): + monkeypatch.delenv("RIMPORT_STAGING", raising=False) - result = rimport.get_staging_root() + result = rimport.get_staging_root() - assert result == rimport.DEFAULT_STAGING_ROOT + assert result == rimport.DEFAULT_STAGING_ROOT - def test_returns_env_value_when_set(self, tmp_path): + def test_returns_env_value_when_set(self, tmp_path, monkeypatch): """Test that RIMPORT_STAGING environment variable is used when set.""" custom_staging = tmp_path / "custom_staging" custom_staging.mkdir() - with patch.dict(os.environ, {"RIMPORT_STAGING": str(custom_staging)}): - result = rimport.get_staging_root() + monkeypatch.setenv("RIMPORT_STAGING", str(custom_staging)) + result = rimport.get_staging_root() - assert result == custom_staging.resolve() + assert result == custom_staging.resolve() - def test_expands_tilde_in_env_value(self): + def test_expands_tilde_in_env_value(self, monkeypatch): """Test that ~ is expanded in RIMPORT_STAGING value.""" # Use a path with ~ that will be expanded - with patch.dict(os.environ, {"RIMPORT_STAGING": "~/my_staging"}): - result = rimport.get_staging_root() + monkeypatch.setenv("RIMPORT_STAGING", "~/my_staging") + result = rimport.get_staging_root() - # Should be expanded and resolved - assert "~" not in str(result) - assert 
result.is_absolute() + # Should be expanded and resolved + assert "~" not in str(result) + assert result.is_absolute() - def test_resolves_relative_path_in_env_value(self): + def test_resolves_relative_path_in_env_value(self, monkeypatch): """Test that relative paths in RIMPORT_STAGING are resolved.""" # Set a relative path - with patch.dict(os.environ, {"RIMPORT_STAGING": "./staging"}): - result = rimport.get_staging_root() + monkeypatch.setenv("RIMPORT_STAGING", "./staging") + result = rimport.get_staging_root() - # Should be resolved to absolute path - assert result.is_absolute() + # Should be resolved to absolute path + assert result.is_absolute() - def test_env_value_with_spaces(self, tmp_path): + def test_env_value_with_spaces(self, tmp_path, monkeypatch): """Test handling of RIMPORT_STAGING with spaces in path.""" custom_staging = tmp_path / "staging with spaces" custom_staging.mkdir() - with patch.dict(os.environ, {"RIMPORT_STAGING": str(custom_staging)}): - result = rimport.get_staging_root() + monkeypatch.setenv("RIMPORT_STAGING", str(custom_staging)) + result = rimport.get_staging_root() - assert result == custom_staging.resolve() + assert result == custom_staging.resolve() - def test_env_value_overrides_default(self, tmp_path): + def test_env_value_overrides_default(self, tmp_path, monkeypatch): """Test that RIMPORT_STAGING overrides the default value.""" custom_staging = tmp_path / "override" custom_staging.mkdir() - with patch.dict(os.environ, {"RIMPORT_STAGING": str(custom_staging)}): - result = rimport.get_staging_root() + monkeypatch.setenv("RIMPORT_STAGING", str(custom_staging)) + result = rimport.get_staging_root() - # Should NOT be the default - assert result != rimport.DEFAULT_STAGING_ROOT - assert result == custom_staging.resolve() + # Should NOT be the default + assert result != rimport.DEFAULT_STAGING_ROOT + assert result == custom_staging.resolve() diff --git a/tests/rimport/test_main.py b/tests/rimport/test_main.py new file mode 100644 
index 0000000..b35ae2d --- /dev/null +++ b/tests/rimport/test_main.py @@ -0,0 +1,369 @@ +""" +Tests for main() function in rimport script. + +These tests focus on the logic and control flow in main(), mocking out +the helper functions to isolate main()'s behavior. +""" + +import os +import importlib.util +from importlib.machinery import SourceFileLoader +from unittest.mock import patch, call +import pytest + +# pylint: disable=too-many-arguments,too-many-positional-arguments + +# Import rimport module from file without .py extension +rimport_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "rimport", +) +loader = SourceFileLoader("rimport", rimport_path) +spec = importlib.util.spec_from_loader("rimport", loader) +if spec is None: + raise ImportError(f"Could not create spec for rimport from {rimport_path}") +rimport = importlib.util.module_from_spec(spec) +# Don't add to sys.modules to avoid conflict with other test files (patches here not being applied) +loader.exec_module(rimport) + + +class TestMain: + """Test suite for main() function.""" + + @patch.object(rimport, "stage_data") + @patch.object(rimport, "get_staging_root") + @patch.object(rimport, "normalize_paths") + @patch.object(rimport, "ensure_running_as") + def test_single_file_success( + self, + _mock_ensure_running_as, + mock_normalize_paths, + mock_get_staging_root, + mock_stage_data, + tmp_path, + ): + """Test main() logic flow when a single file stages successfully.""" + # Setup + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + staging_root = tmp_path / "staging" + staging_root.mkdir() + + mock_get_staging_root.return_value = staging_root + test_file = inputdata_root / "test.nc" + mock_normalize_paths.return_value = [test_file] + + # Run + result = rimport.main(["-file", "test.nc", "-inputdata", str(inputdata_root)]) + + # Verify + assert result == 0 + mock_normalize_paths.assert_called_once_with(inputdata_root, ["test.nc"]) + 
mock_stage_data.assert_called_once_with( + test_file, inputdata_root, staging_root, False + ) + + @patch.object(rimport, "stage_data") + @patch.object(rimport, "get_staging_root") + @patch.object(rimport, "normalize_paths") + @patch.object(rimport, "read_filelist") + @patch.object(rimport, "ensure_running_as") + def test_file_list_success( + self, + _mock_ensure_running_as, + mock_read_filelist, + mock_normalize_paths, + mock_get_staging_root, + mock_stage_data, + tmp_path, + ): + """Test main() logic flow when a file list stages successfully.""" + # Setup + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + staging_root = tmp_path / "staging" + staging_root.mkdir() + filelist = tmp_path / "files.txt" + filelist.write_text("file1.nc\nfile2.nc\n") + + mock_get_staging_root.return_value = staging_root + mock_read_filelist.return_value = ["file1.nc", "file2.nc"] + file1 = inputdata_root / "file1.nc" + file2 = inputdata_root / "file2.nc" + mock_normalize_paths.return_value = [file1, file2] + + # Run + result = rimport.main( + ["-list", str(filelist), "-inputdata", str(inputdata_root)] + ) + + # Verify + assert result == 0 + mock_read_filelist.assert_called_once_with(filelist) + mock_normalize_paths.assert_called_once_with( + inputdata_root, ["file1.nc", "file2.nc"] + ) + assert mock_stage_data.call_count == 2 + mock_stage_data.assert_has_calls( + [ + call(file1, inputdata_root, staging_root, False), + call(file2, inputdata_root, staging_root, False), + ] + ) + + @patch.object(rimport, "stage_data") + @patch.object(rimport, "get_staging_root") + @patch.object(rimport, "normalize_paths") + @patch.object(rimport, "ensure_running_as") + def test_stage_data_exception_handling( + self, + _mock_ensure_running_as, + mock_normalize_paths, + _mock_get_staging_root, + mock_stage_data, + tmp_path, + capsys, + ): + """Test that main() handles exceptions from stage_data and continues processing.""" + # Setup + inputdata_root = tmp_path / "inputdata" + 
inputdata_root.mkdir() + + file1 = inputdata_root / "file1.nc" + file2 = inputdata_root / "file2.nc" + file3 = inputdata_root / "file3.nc" + mock_normalize_paths.return_value = [file1, file2, file3] + + # Make stage_data fail for file2 but succeed for others + def stage_data_side_effect(src, *_args, **_kwargs): + if src == file2: + raise RuntimeError("Test error for file2") + + mock_stage_data.side_effect = stage_data_side_effect + + # Run + result = rimport.main(["-file", "test.nc", "-inputdata", str(inputdata_root)]) + + # Verify + assert result == 1 # Should return 1 because of error + assert mock_stage_data.call_count == 3 # All files should be attempted + + # Check that error was printed to stderr + captured = capsys.readouterr() + assert "error processing" in captured.err + assert "Test error for file2" in captured.err + + @patch.object(rimport, "ensure_running_as") + def test_nonexistent_inputdata_directory( + self, _mock_ensure_running_as, tmp_path, capsys + ): + """Test that argument parser rejects nonexistent inputdata directory.""" + nonexistent = tmp_path / "nonexistent" + + with pytest.raises(SystemExit) as exc_info: + rimport.main(["-file", "test.nc", "-inputdata", str(nonexistent)]) + + assert exc_info.value.code == 2 + captured = capsys.readouterr() + assert "does not exist" in captured.err + + @patch.object(rimport, "ensure_running_as") + def test_nonexistent_filelist(self, _mock_ensure_running_as, tmp_path, capsys): + """Test that main() returns error code 2 for nonexistent file list.""" + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + nonexistent_list = tmp_path / "nonexistent.txt" + + result = rimport.main( + ["-list", str(nonexistent_list), "-inputdata", str(inputdata_root)] + ) + + assert result == 2 + captured = capsys.readouterr() + assert "list file not found" in captured.err + + @patch.object(rimport, "read_filelist") + @patch.object(rimport, "ensure_running_as") + def test_empty_filelist( + self, 
_mock_ensure_running_as, mock_read_filelist, tmp_path, capsys + ): + """Test that main() returns error code 2 for empty file list.""" + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + filelist = tmp_path / "empty.txt" + filelist.write_text("") + + mock_read_filelist.return_value = [] + + result = rimport.main( + ["-list", str(filelist), "-inputdata", str(inputdata_root)] + ) + + assert result == 2 + captured = capsys.readouterr() + assert "no filenames found in list" in captured.err + + @patch.object(rimport, "ensure_running_as") + def test_requires_file_or_filelist(self, _mock_ensure_running_as, tmp_path, capsys): + """Test that main() returns error code 2 if neither file nor filelist provided.""" + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + + result = rimport.main(["-inputdata", str(inputdata_root)]) + + assert result == 2 + captured = capsys.readouterr() + assert "At least one of --file or --filelist is required" in captured.err + + @patch.object(rimport, "stage_data") + @patch.object(rimport, "get_staging_root") + @patch.object(rimport, "normalize_paths") + @patch.object(rimport, "ensure_running_as") + def test_check_mode_calls( + self, + mock_ensure_running_as, + mock_normalize_paths, + mock_get_staging_root, + mock_stage_data, + tmp_path, + ): + """Test that --check mode skips the user check but does call stage_data.""" + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + staging_root = tmp_path / "staging" + staging_root.mkdir() + + mock_get_staging_root.return_value = staging_root + test_file = inputdata_root / "test.nc" + mock_normalize_paths.return_value = [test_file] + + result = rimport.main( + ["-file", "test.nc", "-inputdata", str(inputdata_root), "--check"] + ) + + assert result == 0 + # ensure_running_as should NOT be called in check mode + mock_ensure_running_as.assert_not_called() + # stage_data should be called with check=True + mock_stage_data.assert_called_once_with( + test_file, 
inputdata_root, staging_root, True + ) + + @patch.object(rimport, "stage_data") + @patch.object(rimport, "get_staging_root") + @patch.object(rimport, "normalize_paths") + @patch.object(rimport, "ensure_running_as") + def test_skip_user_check_env_var( + self, + mock_ensure_running_as, + mock_normalize_paths, + _mock_get_staging_root, + _mock_stage, + tmp_path, + monkeypatch, + ): + """Test that RIMPORT_SKIP_USER_CHECK=1 skips the user check.""" + monkeypatch.setenv("RIMPORT_SKIP_USER_CHECK", "1") + + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + + test_file = inputdata_root / "test.nc" + mock_normalize_paths.return_value = [test_file] + + result = rimport.main(["-file", "test.nc", "-inputdata", str(inputdata_root)]) + + assert result == 0 + # ensure_running_as should NOT be called when env var is set + mock_ensure_running_as.assert_not_called() + + @patch.object(rimport, "stage_data") + @patch.object(rimport, "get_staging_root") + @patch.object(rimport, "normalize_paths") + @patch.object(rimport, "ensure_running_as") + def test_prints_file_path_before_processing( + self, + _mock_ensure_running_as, + mock_normalize_paths, + _mock_get_staging_root, + _mock_stage, + tmp_path, + capsys, + ): + """Test that main() prints each file path before processing.""" + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + file1 = inputdata_root / "file1.nc" + file2 = inputdata_root / "file2.nc" + mock_normalize_paths.return_value = [file1, file2] + + result = rimport.main(["-file", "test.nc", "-inputdata", str(inputdata_root)]) + + assert result == 0 + captured = capsys.readouterr() + # Check that file paths are printed with quotes + assert f"'{file1}':" in captured.out + assert f"'{file2}':" in captured.out + + @patch.object(rimport, "stage_data") + @patch.object(rimport, "get_staging_root") + @patch.object(rimport, "normalize_paths") + @patch.object(rimport, "ensure_running_as") + def test_multiple_errors_returns_1( + self, + 
_mock_ensure_running_as, + mock_normalize_paths, + _mock_get_staging_root, + mock_stage_data, + tmp_path, + ): + """Test that main() returns 1 when multiple files fail.""" + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + + file1 = inputdata_root / "file1.nc" + file2 = inputdata_root / "file2.nc" + file3 = inputdata_root / "file3.nc" + mock_normalize_paths.return_value = [file1, file2, file3] + + # Make all files fail + mock_stage_data.side_effect = RuntimeError("Test error") + + result = rimport.main(["-file", "test.nc", "-inputdata", str(inputdata_root)]) + + assert result == 1 + assert mock_stage_data.call_count == 3 + + @patch.object(rimport, "stage_data") + @patch.object(rimport, "get_staging_root") + @patch.object(rimport, "normalize_paths") + @patch.object(rimport, "ensure_running_as") + def test_error_counter_increments_correctly( + self, + _mock_ensure_running_as, + mock_normalize_paths, + _mock_get_staging_root, + mock_stage_data, + tmp_path, + capsys, + ): + """Test that the error counter increments for each failed file.""" + inputdata_root = tmp_path / "inputdata" + inputdata_root.mkdir() + + files = [inputdata_root / f"file{i}.nc" for i in range(5)] + mock_normalize_paths.return_value = files + + # Make files 1 and 3 fail + def stage_data_side_effect(src, *_args, **_kwargs): + if src in [files[1], files[3]]: + raise RuntimeError(f"Test error for {src.name}") + + mock_stage_data.side_effect = stage_data_side_effect + + result = rimport.main(["-file", "test.nc", "-inputdata", str(inputdata_root)]) + + assert result == 1 + captured = capsys.readouterr() + # Should have 2 error messages + assert captured.err.count("error processing") == 2 diff --git a/tests/rimport/test_normalize_paths.py b/tests/rimport/test_normalize_paths.py index 594569a..29d3c15 100644 --- a/tests/rimport/test_normalize_paths.py +++ b/tests/rimport/test_normalize_paths.py @@ -3,7 +3,6 @@ """ import os -import sys import importlib.util from importlib.machinery import 
SourceFileLoader from pathlib import Path @@ -20,7 +19,7 @@ if spec is None: raise ImportError(f"Could not create spec for rimport from {rimport_path}") rimport = importlib.util.module_from_spec(spec) -sys.modules["rimport"] = rimport +# Don't add to sys.modules to avoid conflict with other test files loader.exec_module(rimport) diff --git a/tests/rimport/test_read_filelist.py b/tests/rimport/test_read_filelist.py index 4ec39da..af7f904 100644 --- a/tests/rimport/test_read_filelist.py +++ b/tests/rimport/test_read_filelist.py @@ -3,7 +3,6 @@ """ import os -import sys import importlib.util from importlib.machinery import SourceFileLoader @@ -17,7 +16,7 @@ if spec is None: raise ImportError(f"Could not create spec for rimport from {rimport_path}") rimport = importlib.util.module_from_spec(spec) -sys.modules["rimport"] = rimport +# Don't add to sys.modules to avoid conflict with other test files loader.exec_module(rimport) diff --git a/tests/rimport/test_stage_data.py b/tests/rimport/test_stage_data.py index d24e769..85c3bb1 100644 --- a/tests/rimport/test_stage_data.py +++ b/tests/rimport/test_stage_data.py @@ -3,13 +3,15 @@ """ import os -import sys +import logging import importlib.util from importlib.machinery import SourceFileLoader from unittest.mock import patch import pytest +import shared + # Import rimport module from file without .py extension rimport_path = os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), @@ -20,10 +22,19 @@ if spec is None: raise ImportError(f"Could not create spec for rimport from {rimport_path}") rimport = importlib.util.module_from_spec(spec) -sys.modules["rimport"] = rimport +# Don't add to sys.modules to avoid conflict with other test files loader.exec_module(rimport) +@pytest.fixture(autouse=True) +def configure_logging_for_tests(): + """Configure logging for all tests in this module.""" + shared.configure_logging(logging.INFO) + yield + # Cleanup + rimport.logger.handlers.clear() + + 
@pytest.fixture(name="inputdata_root") def fixture_inputdata_root(tmp_path): """Create and return an inputdata root directory.""" @@ -57,7 +68,7 @@ def test_copies_file_to_staging(self, inputdata_root, staging_root): assert dst.exists() assert dst.read_text() == "data content" - def test_check_doesnt_copy(self, inputdata_root, staging_root, capsys): + def test_check_doesnt_copy(self, inputdata_root, staging_root, caplog): """Test that a file is NOT copied to the staging directory if check is True""" # Create file in inputdata root src = inputdata_root / "file.nc" @@ -70,8 +81,8 @@ def test_check_doesnt_copy(self, inputdata_root, staging_root, capsys): dst = staging_root / "file.nc" assert not dst.exists() - # Verify message was printed - assert "not already published" in capsys.readouterr().out.strip() + # Verify message was logged + assert "not already published" in caplog.text def test_preserves_directory_structure(self, inputdata_root, staging_root): """Test that directory structure is preserved in staging.""" @@ -89,7 +100,7 @@ def test_preserves_directory_structure(self, inputdata_root, staging_root): assert dst.read_text() == "nested data" def test_prints_live_symlink_already_published_not_downloadable( - self, inputdata_root, staging_root, capsys + self, inputdata_root, staging_root, caplog ): """ Test that staging a live, already-published symlink prints a message and returns @@ -106,22 +117,22 @@ def test_prints_live_symlink_already_published_not_downloadable( # Should print message for live symlink and return early rimport.stage_data(src, inputdata_root, staging_root) - # Verify the right messages were printed - stdout = capsys.readouterr().out.strip() + # Verify the right messages were logged msg = "File is already published and linked" - assert msg in stdout + assert msg in caplog.text msg = "File is not (yet) available for download" - assert msg in stdout + assert msg in caplog.text - # Verify the WRONG message was NOT printed + # Verify the WRONG 
message was NOT logged msg = "is already under staging directory" - assert msg not in stdout + assert msg not in caplog.text # Verify that shutil.copy2 was never called (function returned early) mock_copy.assert_not_called() + @patch.object(rimport, "can_file_be_downloaded") def test_prints_live_symlink_already_published_is_downloadable( - self, inputdata_root, staging_root, capsys + self, mock_can_file_be_downloaded, inputdata_root, staging_root, caplog ): """ Like test_prints_live_symlink_already_published_not_downloadable, but mocks @@ -133,29 +144,30 @@ def test_prints_live_symlink_already_published_is_downloadable( src = inputdata_root / "link.nc" src.symlink_to(real_file) + # Mock can_file_be_downloaded to return True + mock_can_file_be_downloaded.return_value = True + # Mock shutil.copy2 to verify it's never called with patch("shutil.copy2") as mock_copy: - # Mock can_file_be_downloaded to return True - with patch("rimport.can_file_be_downloaded", return_value=True): - # Should print message for live symlink and return early - rimport.stage_data(src, inputdata_root, staging_root) + # Should print message for live symlink and return early + rimport.stage_data(src, inputdata_root, staging_root) - # Verify that shutil.copy2 was never called (function returned early) - mock_copy.assert_not_called() + # Verify that shutil.copy2 was never called (function returned early) + mock_copy.assert_not_called() - # Verify the right messages were printed - stdout = capsys.readouterr().out.strip() + # Verify the right messages were logged msg = "File is already published and linked" - assert msg in stdout + assert msg in caplog.text msg = "File is available for download" - assert msg in stdout + assert msg in caplog.text - # Verify the WRONG message was NOT printed + # Verify the WRONG message was NOT logged msg = "is already under staging directory" - assert msg not in stdout + assert msg not in caplog.text + @patch.object(rimport, "can_file_be_downloaded") def 
test_prints_published_but_not_linked( - self, inputdata_root, staging_root, capsys + self, mock_can_file_be_downloaded, inputdata_root, staging_root, caplog ): """ Tests printed message for when a file has been published (copied to staging root) but not @@ -168,24 +180,25 @@ def test_prints_published_but_not_linked( inputdata = inputdata_root / filename inputdata.write_text("data") + # Mock can_file_be_downloaded to return True + mock_can_file_be_downloaded.return_value = True + # Mock shutil.copy2 to verify it's never called with patch("shutil.copy2") as mock_copy: - # Mock can_file_be_downloaded to return True - with patch("rimport.can_file_be_downloaded", return_value=True): - # Should print message for live symlink and return early - rimport.stage_data(inputdata, inputdata_root, staging_root) - # Verify that shutil.copy2 was never called (function returned early) - mock_copy.assert_not_called() + # Should print message for live symlink and return early + rimport.stage_data(inputdata, inputdata_root, staging_root) + + # Verify that shutil.copy2 was never called (function returned early) + mock_copy.assert_not_called() - # Verify the right messages were printed or not - stdout = capsys.readouterr().out.strip() + # Verify the right messages were logged or not msg = "File is already published and linked" - assert msg not in stdout + assert msg not in caplog.text msg = "File is already published but NOT linked; do" - assert msg in stdout + assert msg in caplog.text msg = "File is available for download" - assert msg in stdout + assert msg in caplog.text def test_raises_error_for_live_symlink_pointing_somewhere_other_than_staging( self, tmp_path, inputdata_root, staging_root diff --git a/tests/shared/test_configure_logging.py b/tests/shared/test_configure_logging.py new file mode 100644 index 0000000..918431f --- /dev/null +++ b/tests/shared/test_configure_logging.py @@ -0,0 +1,140 @@ +""" +Tests for shared configure_logging() function. 
+""" + +import logging + +import pytest + +import shared + +logger = logging.getLogger(__name__) + + +class TestConfigureLogging: + """Test suite for configure_logging() function.""" + + @pytest.fixture(autouse=True) + def cleanup_logger(self): + """Clean up logger handlers after each test.""" + yield + # Clear handlers after test + logger.handlers.clear() + + def test_sets_logger_level_to_info(self): + """Test that configure_logging sets the logger level to INFO.""" + shared.configure_logging(logging.INFO, logger) + assert logger.level == logging.INFO + + def test_creates_two_handlers(self): + """Test that configure_logging creates exactly two handlers.""" + shared.configure_logging(logging.INFO, logger) + assert len(logger.handlers) == 2 + + def test_info_handler_goes_to_stdout(self, capsys): + """Test that INFO level messages go to stdout.""" + shared.configure_logging(logging.INFO, logger) + logger.info("Test info message") + + captured = capsys.readouterr() + assert "Test info message" in captured.out + assert captured.err == "" + + def test_warning_handler_goes_to_stdout(self, capsys): + """Test that WARNING level messages go to stdout.""" + shared.configure_logging(logging.INFO, logger) + logger.warning("Test warning message") + + captured = capsys.readouterr() + assert "Test warning message" in captured.out + assert captured.err == "" + + def test_error_handler_goes_to_stderr(self, capsys): + """Test that ERROR level messages go to stderr.""" + shared.configure_logging(logging.INFO, logger) + logger.error("Test error message") + + captured = capsys.readouterr() + assert captured.out == "" + assert "Test error message" in captured.err + + def test_critical_handler_goes_to_stderr(self, capsys): + """Test that CRITICAL level messages go to stderr.""" + shared.configure_logging(logging.INFO, logger) + logger.critical("Test critical message") + + captured = capsys.readouterr() + assert captured.out == "" + assert "Test critical message" in captured.err + + def 
test_clears_existing_handlers(self): + """Test that configure_logging clears any existing handlers.""" + # Add a dummy handler + from io import StringIO + + dummy_handler = logging.StreamHandler(StringIO()) + logger.addHandler(dummy_handler) + assert len(logger.handlers) >= 1 + + # Configure logging + shared.configure_logging(logging.INFO, logger) + + # Verify old handlers were cleared and new ones added + assert len(logger.handlers) == 2 + assert dummy_handler not in logger.handlers + + def test_formatter_uses_message_only(self, capsys): + """Test that the formatter outputs only the message without level/timestamp.""" + shared.configure_logging(logging.INFO, logger) + logger.info("Simple message") + + captured = capsys.readouterr() + output = captured.out.strip() + assert output == "Simple message" + assert "INFO" not in output + + def test_multiple_calls_dont_duplicate_handlers(self): + """Test that calling configure_logging multiple times doesn't duplicate handlers.""" + shared.configure_logging(logging.INFO, logger) + assert len(logger.handlers) == 2 + + shared.configure_logging(logging.INFO, logger) + assert len(logger.handlers) == 2 # Still 2, not 4 + + shared.configure_logging(logging.INFO, logger) + assert len(logger.handlers) == 2 # Still 2, not 6 + + def test_configure_with_debug_level(self, capsys): + """Test that configure_logging accepts DEBUG level.""" + shared.configure_logging(logging.DEBUG, logger) + + # DEBUG messages should now be logged + logger.debug("Debug message") + + captured = capsys.readouterr() + assert "Debug message" in captured.out + assert captured.err == "" + + def test_configure_with_warning_level(self, capsys): + """Test that configure_logging accepts WARNING level.""" + shared.configure_logging(logging.WARNING, logger) + + # INFO messages should be suppressed + logger.info("Info message") + # WARNING messages should be logged + logger.warning("Warning message") + + captured = capsys.readouterr() + assert "Info message" not in 
captured.out + assert "Warning message" in captured.out + assert captured.err == "" + + def test_configure_with_info_level_suppresses_debug(self, capsys): + """Test that INFO level suppresses DEBUG messages.""" + shared.configure_logging(logging.INFO, logger) + + logger.debug("Debug message") + logger.info("Info message") + + captured = capsys.readouterr() + assert "Debug message" not in captured.out diff --git a/tests/shared/test_get_log_level.py b/tests/shared/test_get_log_level.py new file mode 100644 index 0000000..87ee1e4 --- /dev/null +++ b/tests/shared/test_get_log_level.py @@ -0,0 +1,35 @@ +""" +Tests for shared.py get_log_level() function. +""" + +import logging +import shared + + +class TestGetLogLevel: + """Test suite for get_log_level() function.""" + + def test_default_returns_info(self): + """Test that default (no flags) returns INFO level.""" + result = shared.get_log_level() + assert result == logging.INFO + + def test_quiet_returns_warning(self): + """Test that quiet=True returns WARNING level.""" + result = shared.get_log_level(quiet=True) + assert result == logging.WARNING + + def test_verbose_returns_debug(self): + """Test that verbose=True returns DEBUG level.""" + result = shared.get_log_level(verbose=True) + assert result == logging.DEBUG + + def test_quiet_takes_precedence_over_verbose(self): + """Test that quiet takes precedence when both are True.""" + result = shared.get_log_level(quiet=True, verbose=True) + assert result == logging.WARNING + + def test_quiet_false_verbose_false(self): + """Test explicit False values return INFO.""" + result = shared.get_log_level(quiet=False, verbose=False) + assert result == logging.INFO diff --git a/tests/shared/test_validate_paths.py b/tests/shared/test_validate_paths.py new file mode 100644 index 0000000..d7b7a89 --- /dev/null +++ b/tests/shared/test_validate_paths.py @@ -0,0 +1,150 @@ +""" +Tests for shared.py validate_directory() and validate_paths() functions. 
+""" + +import os +import argparse + +import pytest + +import shared + + +class TestValidateDirectory: + """Test suite for validate_directory function.""" + + def test_valid_directory(self, tmp_path): + """Test that valid directory is accepted and returns absolute path.""" + test_dir = tmp_path / "valid_dir" + test_dir.mkdir() + + result = shared.validate_directory(str(test_dir)) + assert result == str(test_dir.resolve()) + + def test_nonexistent_directory(self): + """Test that nonexistent directory raises ArgumentTypeError.""" + nonexistent = os.path.join(os.sep, "nonexistent", "directory", "12345") + + with pytest.raises(argparse.ArgumentTypeError) as exc_info: + shared.validate_directory(nonexistent) + + assert "does not exist" in str(exc_info.value) + assert nonexistent in str(exc_info.value) + + def test_relative_path_converted_to_absolute(self, tmp_path): + """Test that relative paths are converted to absolute.""" + test_dir = tmp_path / "relative_test" + test_dir.mkdir() + + # Change to parent directory and use relative path + cwd = os.getcwd() + try: + os.chdir(str(tmp_path)) + result = shared.validate_directory("relative_test") + assert os.path.isabs(result) + assert result == str(test_dir.resolve()) + finally: + os.chdir(cwd) + + def test_symlink_to_directory(self, tmp_path): + """Test that symlink to a directory is accepted.""" + real_dir = tmp_path / "real_dir" + real_dir.mkdir() + + link_dir = tmp_path / "link_dir" + link_dir.symlink_to(real_dir) + + result = shared.validate_paths(str(link_dir)) + # NOTE(review): calls validate_paths, not validate_directory — possible copy-paste; confirm + assert result == str(link_dir.absolute()) + # Verify it's still a symlink + assert os.path.islink(result) + + def test_list_with_invalid_directory(self, tmp_path): + """Test that a list with one invalid directory raises error.""" + dir1 = tmp_path / "dir1" + dir1.mkdir() + nonexistent = tmp_path / "nonexistent" + + with pytest.raises(argparse.ArgumentTypeError) as exc_info: + 
shared.validate_paths([str(dir1), str(nonexistent)]) + + assert "does not exist" in str(exc_info.value) + + +class TestValidatePaths: + """Test suite for validate_paths function.""" + + def test_valid_directory(self, tmp_path): + """Test that valid directory is accepted and returns absolute path.""" + test_dir = tmp_path / "valid_dir" + test_dir.mkdir() + + result = shared.validate_paths(str(test_dir)) + assert result == str(test_dir.resolve()) + + def test_nonexistent_directory(self): + """Test that nonexistent directory raises ArgumentTypeError.""" + nonexistent = os.path.join(os.sep, "nonexistent", "directory", "12345") + + with pytest.raises(argparse.ArgumentTypeError) as exc_info: + shared.validate_paths(nonexistent) + + assert "does not exist" in str(exc_info.value) + assert nonexistent in str(exc_info.value) + + def test_file_instead_of_directory(self, tmp_path): + """Test that a file path doesn't raise ArgumentTypeError (or any error).""" + test_file = tmp_path / "test_file.txt" + test_file.write_text("content") + + shared.validate_paths(str(test_file)) + + def test_relative_path_converted_to_absolute(self, tmp_path): + """Test that relative paths are converted to absolute.""" + test_dir = tmp_path / "relative_test" + test_dir.mkdir() + + # Change to parent directory and use relative path + cwd = os.getcwd() + try: + os.chdir(str(tmp_path)) + result = shared.validate_paths("relative_test") + assert os.path.isabs(result) + assert result == str(test_dir.resolve()) + finally: + os.chdir(cwd) + + def test_symlink_to_directory(self, tmp_path): + """Test that symlink to a directory is accepted.""" + real_dir = tmp_path / "real_dir" + real_dir.mkdir() + + link_dir = tmp_path / "link_dir" + link_dir.symlink_to(real_dir) + + result = shared.validate_paths(str(link_dir)) + # validate_paths returns absolute path of the symlink itself + assert result == str(link_dir.absolute()) + # Verify it's still a symlink + assert os.path.islink(result) + + def 
test_list_with_invalid_directory(self, tmp_path): + """Test that a list with one invalid directory raises error.""" + dir1 = tmp_path / "dir1" + dir1.mkdir() + nonexistent = tmp_path / "nonexistent" + + with pytest.raises(argparse.ArgumentTypeError) as exc_info: + shared.validate_paths([str(dir1), str(nonexistent)]) + + assert "does not exist" in str(exc_info.value) + + def test_list_with_file_instead_of_directory(self, tmp_path): + """Test that a list containing a file doesn't raise error.""" + dir1 = tmp_path / "dir1" + dir1.mkdir() + file1 = tmp_path / "file.txt" + file1.write_text("content") + + shared.validate_paths([str(dir1), str(file1)])