def validate_module_import(module_path: str, project_root: Path) -> tuple[bool, str]:
    """Check whether ``module_path`` can be located with ``project_root`` importable.

    ``project_root`` is temporarily placed at the front of ``sys.path`` (unless it
    is already listed) and ``importlib.util.find_spec`` is asked for the module's
    spec. No subprocess is started and the target module itself is not executed.
    For a dotted name, however, ``find_spec`` imports every parent package, so
    their ``__init__.py`` files do run and those packages remain in
    ``sys.modules`` afterwards.

    Args:
        module_path: Absolute dotted module name, e.g. ``"pkg.sub.mod"``.
        project_root: Directory the dotted name is expected to resolve against.

    Returns:
        ``(True, "")`` when a spec was found, otherwise ``(False, message)``
        describing why the lookup failed. Anything deriving from ``Exception``
        is reported through the tuple instead of being raised.
    """
    import importlib  # function-scope import: the module-level import block stays untouched

    project_root_str = str(project_root)
    added = project_root_str not in sys.path
    if added:
        sys.path.insert(0, project_root_str)
    try:
        # Path finders cache directory listings; drop them so packages that were
        # written moments ago are visible to the lookup.
        importlib.invalidate_caches()
        if find_spec(module_path) is not None:
            return True, ""
        return False, f"Module '{module_path}' not found (find_spec returned None)"
    except ModuleNotFoundError as e:
        # Raised when a parent package of a dotted name cannot be imported.
        return False, str(e)
    except Exception as e:
        # Importing a parent package executes its __init__, which can raise anything.
        return False, f"Error checking module '{module_path}': {e}"
    finally:
        # A parent package's __init__ may have edited sys.path itself; only remove
        # the entry if it is still there so cleanup can never mask the result.
        if added and project_root_str in sys.path:
            sys.path.remove(project_root_str)
def infer_module_root_from_file(file_path: Path, pyproject_dir: Path) -> Path | None:
    """Infer the module-root for a Python file by walking its ``__init__.py`` chain.

    Starting at the file's directory, the walk moves upward while each directory
    contains ``__init__.py`` and stops at the first directory that does not, or
    on reaching ``pyproject_dir`` (which is never itself considered a package).
    The highest package directory seen is the inferred module-root.

    Args:
        file_path: Python source file whose package layout is inspected.
        pyproject_dir: Directory containing ``pyproject.toml``; upper bound of the walk.

    Returns:
        The top-most package directory, or the file's own directory when that
        directory is not a package. ``None`` when the file does not live under
        ``pyproject_dir``, because no module-root inside the project can be
        inferred for it.
    """
    file_path = file_path.resolve()
    pyproject_dir = pyproject_dir.resolve()
    start_dir = file_path.parent
    try:
        start_dir.relative_to(pyproject_dir)
    except ValueError:
        # Outside the project: walking up would end at the filesystem root and
        # yield a module-root that pyproject.toml could only express as "../..".
        return None

    top_package: Path | None = None
    current = start_dir
    # `current.parent` equals `current` only at the filesystem root; kept as a
    # belt-and-braces stop condition.
    while current not in (pyproject_dir, current.parent) and (current / "__init__.py").exists():
        top_package = current
        current = current.parent

    # No __init__.py next to the file — treat the file's own directory as the module-root
    return top_package if top_package is not None else start_dir


def file_path_from_module_name(module_name: str, project_root_path: Path) -> Path:
    """Get file path from module path."""
    return project_root_path / (module_name.replace(".", os.sep) + ".py")
def try_correct_module_root(self) -> bool:
    """Infer the module-root from the ``__init__.py`` chain and apply it if it differs.

    Method of ``PythonFunctionOptimizer``. The candidate root is computed with
    ``infer_module_root_from_file``, checked with ``validate_module_import``,
    persisted to the ``[tool.codeflash]`` table of pyproject.toml and only then
    copied into the in-memory configuration.

    Returns:
        True when a different module-root was inferred, validated, written to
        pyproject.toml and applied to ``self``. False in every other case
        (no pyproject.toml, no args, same root, underivable module name, failed
        import check, or failed write); in-memory state is left untouched then.
    """
    try:
        pyproject_path = find_pyproject_toml(None)
    except ValueError:
        return False

    if self.args is None:
        return False

    pyproject_dir = pyproject_path.parent
    source_file = self.function_to_optimize.file_path
    inferred_root = infer_module_root_from_file(source_file, pyproject_dir)
    if inferred_root is None:
        return False
    new_module_root = inferred_root.resolve()
    if new_module_root == self.args.module_root.resolve():
        return False

    new_project_root = project_root_from_module_root(new_module_root, pyproject_path)
    try:
        new_module_path = module_name_from_file_path(source_file, new_project_root)
    except ValueError:
        return False

    import_ok, _ = validate_module_import(new_module_path, new_project_root)
    if not import_ok:
        return False

    # Import succeeded with the inferred module-root — update pyproject.toml
    try:
        relative_root = os.path.relpath(new_module_root, pyproject_dir)
        data = tomlkit.parse(pyproject_path.read_bytes())
        data["tool"]["codeflash"]["module-root"] = relative_root  # type: ignore[index]
        # Serialize BEFORE touching the file: if dumps() fails, the existing
        # pyproject.toml must not already have been truncated.
        serialized = tomlkit.dumps(data)
        # Bytes in, bytes out: a text-mode write would translate "\n" and turn
        # the CRLF line endings of a Windows-authored file into "\r\r\n".
        pyproject_path.write_bytes(serialized.encode("utf-8"))
    except Exception as e:
        # Best effort: a read-only or malformed pyproject.toml must not abort optimization.
        logger.debug(f"Failed to update pyproject.toml with corrected module-root: {e}")
        return False

    # Update in-memory config
    self.args.module_root = new_module_root
    self.args.project_root = new_project_root
    self.project_root = new_project_root.resolve()
    self.original_module_path = new_module_path

    logger.info(f"Auto-corrected module-root to '{relative_root}' in pyproject.toml")
    return True
def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
    """Gate optimization on the target module being importable.

    Method of ``PythonFunctionOptimizer``. First gives ``try_correct_module_root``
    a chance to repair a misconfigured module-root (its return value is not
    needed here), then verifies that the possibly-updated module path resolves
    under the possibly-updated project root.

    Returns:
        A ``Failure`` with a user-facing explanation when the module cannot be
        found; otherwise whatever the parent class's ``can_be_optimized`` returns.
    """
    self.try_correct_module_root()

    is_importable, failure_reason = validate_module_import(self.original_module_path, self.project_root)
    if is_importable:
        return super().can_be_optimized()

    return Failure(
        f"Cannot import module '{self.original_module_path}': {failure_reason}\n"
        "This prevents test execution. Please check that all dependencies are installed "
        "and that 'module-root' is correctly configured in pyproject.toml."
    )
class TestValidateModuleImport:
    """Tests for ``validate_module_import``: results, error text and ``sys.path`` hygiene."""

    def test_known_stdlib_module(self, tmp_path: Path) -> None:
        ok, err = validate_module_import("json", tmp_path)
        assert ok is True
        assert err == ""

    def test_nonexistent_module(self, tmp_path: Path) -> None:
        ok, err = validate_module_import("totally_nonexistent_module_xyz_123", tmp_path)
        assert ok is False
        assert err != ""

    def test_finds_module_in_project_root(self, tmp_path: Path) -> None:
        pkg = tmp_path / "mypkg"
        pkg.mkdir()
        (pkg / "__init__.py").write_text("", encoding="utf-8")
        (pkg / "utils.py").write_text("x = 1\n", encoding="utf-8")

        already_loaded = "mypkg" in sys.modules
        try:
            ok, err = validate_module_import("mypkg.utils", tmp_path)
        finally:
            # find_spec imports the parent package of a dotted name; drop the
            # temporary package again so it cannot leak into other tests.
            if not already_loaded:
                sys.modules.pop("mypkg", None)
        assert ok is True
        assert err == ""

    def test_project_root_not_left_in_sys_path(self, tmp_path: Path) -> None:
        root_str = str(tmp_path)
        assert root_str not in sys.path
        validate_module_import("nonexistent_mod", tmp_path)
        assert root_str not in sys.path

    def test_project_root_preserved_if_already_in_sys_path(self, tmp_path: Path) -> None:
        # An entry the caller put on sys.path must survive the check.
        root_str = str(tmp_path)
        sys.path.insert(0, root_str)
        try:
            validate_module_import("json", tmp_path)
            assert root_str in sys.path
        finally:
            sys.path.remove(root_str)

    def test_find_spec_returns_none(self, tmp_path: Path) -> None:
        with patch("codeflash.code_utils.code_utils.find_spec", return_value=None) as mock_fs:
            ok, err = validate_module_import("some.mod", tmp_path)
        mock_fs.assert_called_once_with("some.mod")
        assert ok is False
        assert "not found" in err

    def test_find_spec_raises_module_not_found(self, tmp_path: Path) -> None:
        with patch(
            "codeflash.code_utils.code_utils.find_spec",
            side_effect=ModuleNotFoundError("No module named 'boom'"),
        ):
            ok, err = validate_module_import("boom", tmp_path)
        assert ok is False
        assert "boom" in err

    def test_find_spec_raises_generic_exception(self, tmp_path: Path) -> None:
        with patch(
            "codeflash.code_utils.code_utils.find_spec",
            side_effect=RuntimeError("something broke"),
        ):
            ok, err = validate_module_import("broken.mod", tmp_path)
        assert ok is False
        assert "something broke" in err

    def test_sys_path_cleaned_on_exception(self, tmp_path: Path) -> None:
        root_str = str(tmp_path)
        assert root_str not in sys.path
        with patch("codeflash.code_utils.code_utils.find_spec", side_effect=RuntimeError("boom")):
            validate_module_import("mod", tmp_path)
        assert root_str not in sys.path
class TestInferModuleRootFromFile:
    """Tests that ``infer_module_root_from_file`` returns the top-most package directory."""

    @staticmethod
    def _init_packages(*package_dirs: Path) -> None:
        """Create each directory (with parents) and put an empty ``__init__.py`` in it."""
        for package_dir in package_dirs:
            package_dir.mkdir(parents=True, exist_ok=True)
            (package_dir / "__init__.py").write_text("", encoding="utf-8")

    @staticmethod
    def _write_module(directory: Path, filename: str = "mod.py", source: str = "x = 1\n") -> Path:
        """Write a small module into ``directory`` and return its path."""
        directory.mkdir(parents=True, exist_ok=True)
        module_file = directory / filename
        module_file.write_text(source, encoding="utf-8")
        return module_file

    @staticmethod
    def _assert_inferred(module_file: Path, pyproject_dir: Path, expected_root: Path) -> None:
        """Assert that inference yields ``expected_root`` (compared after resolving)."""
        inferred = infer_module_root_from_file(module_file, pyproject_dir)
        assert inferred is not None
        assert inferred.resolve() == expected_root.resolve()

    def test_single_package(self, tmp_path: Path) -> None:
        package = tmp_path / "pkg"
        self._init_packages(package)
        self._assert_inferred(self._write_module(package), tmp_path, package)

    def test_nested_package_returns_top_level(self, tmp_path: Path) -> None:
        package = tmp_path / "pkg"
        subpackage = package / "sub"
        self._init_packages(package, subpackage)
        self._assert_inferred(self._write_module(subpackage), tmp_path, package)

    def test_deeply_nested_package(self, tmp_path: Path) -> None:
        level_a = tmp_path / "a"
        level_b = level_a / "b"
        level_c = level_b / "c"
        self._init_packages(level_a, level_b, level_c)
        self._assert_inferred(self._write_module(level_c), tmp_path, level_a)

    def test_no_init_files_returns_parent_dir(self, tmp_path: Path) -> None:
        scripts_dir = tmp_path / "scripts"
        script = self._write_module(scripts_dir, "run.py", "print('hi')\n")
        self._assert_inferred(script, tmp_path, scripts_dir)

    def test_gap_in_init_chain(self, tmp_path: Path) -> None:
        # "outer" deliberately has no __init__.py, so the chain stops at "inner".
        inner = tmp_path / "outer" / "inner"
        self._init_packages(inner)
        self._assert_inferred(self._write_module(inner), tmp_path, inner)

    def test_file_directly_in_pyproject_dir(self, tmp_path: Path) -> None:
        standalone = self._write_module(tmp_path, "standalone.py")
        self._assert_inferred(standalone, tmp_path, tmp_path)

    def test_src_layout(self, tmp_path: Path) -> None:
        # "src" itself is not a package; only "src/pkg" is.
        package = tmp_path / "src" / "pkg"
        self._init_packages(package)
        self._assert_inferred(self._write_module(package), tmp_path, package)
class TestTryCorrectModuleRoot:
    """Tests for ``PythonFunctionOptimizer.try_correct_module_root``.

    The method is invoked unbound on a ``MagicMock`` stub so no real optimizer
    has to be constructed; collaborators are patched where the optimizer module
    looks them up.
    """

    _TARGET = "codeflash.languages.python.function_optimizer"

    def _make_optimizer_stub(
        self,
        file_path: Path,
        module_root: Path,
        project_root: Path,
        original_module_path: str = "pkg.mod",
    ) -> MagicMock:
        """Build a spec'd mock carrying only the attributes the method reads and writes."""
        from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer

        optimizer = MagicMock(spec=PythonFunctionOptimizer)
        optimizer.function_to_optimize = MagicMock()
        optimizer.function_to_optimize.file_path = file_path
        optimizer.args = MagicMock()
        optimizer.args.module_root = module_root
        optimizer.args.project_root = project_root
        optimizer.project_root = project_root
        optimizer.original_module_path = original_module_path
        return optimizer

    @staticmethod
    def _make_misconfigured_project(tmp_path: Path) -> tuple[Path, Path, Path, Path]:
        """Create ``pkg/sub/mod.py`` with a pyproject.toml whose module-root wrongly says ``pkg/sub``.

        Returns ``(pkg, sub, mod, pyproject)``.
        """
        pkg = tmp_path / "pkg"
        sub = pkg / "sub"
        sub.mkdir(parents=True)
        (pkg / "__init__.py").write_text("", encoding="utf-8")
        (sub / "__init__.py").write_text("", encoding="utf-8")
        mod = sub / "mod.py"
        mod.write_text("x = 1\n", encoding="utf-8")

        pyproject = tmp_path / "pyproject.toml"
        pyproject.write_text('[tool.codeflash]\nmodule-root = "pkg/sub"\n', encoding="utf-8")
        return pkg, sub, mod, pyproject

    def test_returns_false_when_pyproject_not_found(self, tmp_path: Path) -> None:
        from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer

        mod = tmp_path / "pkg" / "mod.py"
        mod.parent.mkdir()
        mod.write_text("x = 1\n", encoding="utf-8")
        optimizer = self._make_optimizer_stub(mod, tmp_path / "pkg", tmp_path)

        with patch(f"{self._TARGET}.find_pyproject_toml", side_effect=ValueError("not found")):
            result = PythonFunctionOptimizer.try_correct_module_root(optimizer)

        assert result is False

    def test_returns_false_when_inferred_same_as_current(self, tmp_path: Path) -> None:
        from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer

        pkg = tmp_path / "pkg"
        pkg.mkdir()
        (pkg / "__init__.py").write_text("", encoding="utf-8")
        mod = pkg / "mod.py"
        mod.write_text("x = 1\n", encoding="utf-8")

        pyproject = tmp_path / "pyproject.toml"
        pyproject.write_text('[tool.codeflash]\nmodule-root = "pkg"\n', encoding="utf-8")

        optimizer = self._make_optimizer_stub(mod, pkg, tmp_path)

        with patch(f"{self._TARGET}.find_pyproject_toml", return_value=pyproject):
            result = PythonFunctionOptimizer.try_correct_module_root(optimizer)

        assert result is False

    def test_corrects_module_root_and_updates_pyproject(self, tmp_path: Path) -> None:
        from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer

        pkg, sub, mod, pyproject = self._make_misconfigured_project(tmp_path)
        optimizer = self._make_optimizer_stub(
            file_path=mod,
            module_root=sub,
            project_root=tmp_path,
            original_module_path="sub.mod",
        )

        with (
            patch(f"{self._TARGET}.find_pyproject_toml", return_value=pyproject),
            patch(f"{self._TARGET}.project_root_from_module_root", return_value=tmp_path),
            patch(f"{self._TARGET}.module_name_from_file_path", return_value="pkg.sub.mod"),
            patch(f"{self._TARGET}.validate_module_import", return_value=(True, "")),
        ):
            result = PythonFunctionOptimizer.try_correct_module_root(optimizer)

        assert result is True
        # The correction is persisted on disk ...
        data = tomlkit.parse(pyproject.read_text(encoding="utf-8"))
        assert data["tool"]["codeflash"]["module-root"] == os.path.relpath(pkg.resolve(), tmp_path)
        # ... and mirrored in the in-memory configuration.
        assert optimizer.args.module_root == pkg.resolve()
        assert optimizer.args.project_root == tmp_path
        assert optimizer.project_root == tmp_path.resolve()
        assert optimizer.original_module_path == "pkg.sub.mod"

    def test_returns_false_when_import_validation_fails(self, tmp_path: Path) -> None:
        from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer

        _pkg, sub, mod, pyproject = self._make_misconfigured_project(tmp_path)
        pyproject_before = pyproject.read_text(encoding="utf-8")
        optimizer = self._make_optimizer_stub(file_path=mod, module_root=sub, project_root=tmp_path)

        with (
            patch(f"{self._TARGET}.find_pyproject_toml", return_value=pyproject),
            patch(f"{self._TARGET}.project_root_from_module_root", return_value=tmp_path),
            patch(f"{self._TARGET}.module_name_from_file_path", return_value="pkg.sub.mod"),
            patch(f"{self._TARGET}.validate_module_import", return_value=(False, "Module not found")),
        ):
            result = PythonFunctionOptimizer.try_correct_module_root(optimizer)

        assert result is False
        # A rejected candidate must leave both the file and the stub untouched.
        assert pyproject.read_text(encoding="utf-8") == pyproject_before
        assert optimizer.args.module_root == sub
        assert optimizer.original_module_path == "pkg.mod"

    def test_returns_false_when_module_name_raises(self, tmp_path: Path) -> None:
        from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer

        _pkg, sub, mod, pyproject = self._make_misconfigured_project(tmp_path)
        pyproject_before = pyproject.read_text(encoding="utf-8")
        optimizer = self._make_optimizer_stub(file_path=mod, module_root=sub, project_root=tmp_path)

        with (
            patch(f"{self._TARGET}.find_pyproject_toml", return_value=pyproject),
            patch(f"{self._TARGET}.project_root_from_module_root", return_value=tmp_path),
            patch(
                f"{self._TARGET}.module_name_from_file_path",
                side_effect=ValueError("cannot derive module name"),
            ),
        ):
            result = PythonFunctionOptimizer.try_correct_module_root(optimizer)

        assert result is False
        assert pyproject.read_text(encoding="utf-8") == pyproject_before
        assert optimizer.args.module_root == sub