Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sys
from contextlib import contextmanager
from functools import lru_cache
from importlib.util import find_spec
from pathlib import Path
from tempfile import TemporaryDirectory

Expand Down Expand Up @@ -355,6 +356,56 @@ def module_name_from_file_path(file_path: Path, project_root_path: Path, *, trav
raise ValueError(msg) # noqa: B904


def validate_module_import(module_path: str, project_root: Path) -> tuple[bool, str]:
"""Check if a module is importable using find_spec (no actual import or subprocess).

Returns (success, error_message). Uses importlib.util.find_spec to check
module availability without triggering module initialization.
"""
project_root_str = str(project_root)
added = False
if project_root_str not in sys.path:
sys.path.insert(0, project_root_str)
added = True
try:
if find_spec(module_path) is not None:
return True, ""
return False, f"Module '{module_path}' not found (find_spec returned None)"
except ModuleNotFoundError as e:
return False, str(e)
except Exception as e:
return False, f"Error checking module '{module_path}': {e}"
finally:
if added:
sys.path.remove(project_root_str)


def infer_module_root_from_file(file_path: Path, pyproject_dir: Path) -> Path | None:
    """Guess the module-root for ``file_path`` from its chain of ``__init__.py`` files.

    Starting at the directory that contains the file, climb toward
    ``pyproject_dir`` while each directory holds an ``__init__.py``. The climb
    stops at the first directory without one, at ``pyproject_dir`` itself (which
    is never inspected), or at the filesystem root. The highest directory that
    passed the check is the top-level package and is returned as the module-root.

    When not even the file's own directory is a package, that directory is
    returned instead. Both inputs are resolved to absolute paths first.

    NOTE: the annotation allows ``None``, but every path through this function
    returns a ``Path``.
    """
    source_file = file_path.resolve()
    stop_dir = pyproject_dir.resolve()
    start_dir = source_file.parent

    outermost_package: Path | None = None
    # Visit the file's directory followed by each of its ancestors in order.
    for candidate in (start_dir, *start_dir.parents):
        # Reached the project directory or the filesystem root: stop climbing.
        if candidate in (stop_dir, candidate.parent):
            break
        # The package chain is broken as soon as a directory lacks __init__.py.
        if not (candidate / "__init__.py").exists():
            break
        outermost_package = candidate

    if outermost_package is None:
        # No __init__.py found — treat the file's own directory as the module-root
        return start_dir
    return outermost_package


def file_path_from_module_name(module_name: str, project_root_path: Path) -> Path:
    """Map a dotted module name to the ``.py`` file it names under ``project_root_path``.

    Each dot in ``module_name`` becomes the platform path separator and ``.py`` is
    appended; the file is not checked for existence.
    """
    relative_file = os.sep.join(module_name.split(".")) + ".py"
    return project_root_path.joinpath(relative_file)
Expand Down
76 changes: 76 additions & 0 deletions codeflash/languages/python/function_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
from __future__ import annotations

import ast
import os
from pathlib import Path
from typing import TYPE_CHECKING

import tomlkit

from codeflash.cli_cmds.cli import project_root_from_module_root
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.code_utils import (
infer_module_root_from_file,
module_name_from_file_path,
validate_module_import,
)
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
from codeflash.code_utils.config_parser import find_pyproject_toml
from codeflash.either import Failure, Success
from codeflash.languages.python.context.unused_definition_remover import (
detect_unused_helper_functions,
Expand Down Expand Up @@ -39,6 +49,72 @@


class PythonFunctionOptimizer(FunctionOptimizer):
def try_correct_module_root(self) -> bool:
    """Try to infer and apply the correct module-root if the current one is wrong.

    Walks the ``__init__.py`` chain of the file being optimized to infer a
    module-root. When that differs from the configured one and the module is
    locatable under it, the new value is written to ``[tool.codeflash]`` in
    pyproject.toml and mirrored into the in-memory configuration.

    Returns:
        True when the module-root was corrected on disk and in memory; False when
        no correction was needed, inference or validation failed, or
        pyproject.toml could not be updated (nothing in memory is changed then).
    """
    try:
        pyproject_path = find_pyproject_toml(None)
    except ValueError:
        return False

    if self.args is None:
        return False

    pyproject_dir = pyproject_path.parent
    inferred_root = infer_module_root_from_file(self.function_to_optimize.file_path, pyproject_dir)
    if inferred_root is None or inferred_root.resolve() == self.args.module_root.resolve():
        return False

    new_module_root = inferred_root.resolve()
    new_project_root = project_root_from_module_root(new_module_root, pyproject_path)
    try:
        new_module_path = module_name_from_file_path(self.function_to_optimize.file_path, new_project_root)
    except ValueError:
        return False

    import_ok, _ = validate_module_import(new_module_path, new_project_root)
    if not import_ok:
        return False

    # Import succeeded with the inferred module-root — update pyproject.toml
    try:
        with pyproject_path.open("rb") as f:
            data = tomlkit.parse(f.read())
        relative_root = os.path.relpath(new_module_root, pyproject_dir)
        data["tool"]["codeflash"]["module-root"] = relative_root  # type: ignore[index]
        # Serialize BEFORE reopening the file for writing: mode "w" truncates
        # immediately, so a failure in dumps() must not leave pyproject.toml empty.
        serialized = tomlkit.dumps(data)
        # The file was read as bytes, so its original line endings are preserved in
        # the document; newline="" stops text-mode translation from turning an
        # existing "\r\n" into "\r\r\n" on Windows.
        with pyproject_path.open("w", encoding="utf-8", newline="") as f:
            f.write(serialized)
    except Exception:
        logger.debug("Failed to update pyproject.toml with corrected module-root")
        return False

    # Update in-memory config only after the on-disk config was written successfully
    self.args.module_root = new_module_root
    self.args.project_root = new_project_root
    self.project_root = new_project_root.resolve()
    self.original_module_path = new_module_path

    logger.info(f"Auto-corrected module-root to '{relative_root}' in pyproject.toml")
    return True

def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
    """Gate the base-class optimizability check on the target module being importable.

    First gives ``try_correct_module_root`` a chance to repair a misconfigured
    module-root (its result is intentionally ignored), then checks that the
    module can be located under the current project root. Returns a ``Failure``
    with a diagnostic message when it cannot; otherwise defers to the parent
    implementation.
    """
    # Best-effort repair: whether or not it changed anything, validate afterwards.
    self.try_correct_module_root()

    importable, reason = validate_module_import(self.original_module_path, self.project_root)
    if importable:
        return super().can_be_optimized()

    failure_message = (
        f"Cannot import module '{self.original_module_path}': {reason}\n"
        "This prevents test execution. Please check that all dependencies are installed "
        "and that 'module-root' is correctly configured in pyproject.toml."
    )
    return Failure(failure_message)

def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
from codeflash.languages.python.context import code_context_extractor

Expand Down
Loading
Loading