diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index e3e836ad2..73d09a127 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -31,6 +31,9 @@ TOTAL_LOOPING_TIME_LSP = 10.0 # Kept same timing for LSP mode to avoid in increase in performance reporting N_CANDIDATES_LP_LSP = 3 +# setting this value to 1 will disable repair if there is at least one correct candidate +MIN_CORRECT_CANDIDATES = 2 + # Code repair REPAIR_UNMATCHED_PERCENTAGE_LIMIT = 0.4 # if the percentage of unmatched tests is greater than this, we won't fix it (lowering this value makes the repair more stricted) MAX_REPAIRS_PER_TRACE = 4 # maximum number of repairs we will do for each function diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 547aefb39..469082d53 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -48,6 +48,7 @@ INDIVIDUAL_TESTCASE_TIMEOUT, MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE, MAX_REPAIRS_PER_TRACE, + MIN_CORRECT_CANDIDATES, N_TESTS_TO_GENERATE_EFFECTIVE, REFINE_ALL_THRESHOLD, REFINED_CANDIDATE_RANKING_WEIGHTS, @@ -887,6 +888,7 @@ def process_single_candidate( baseline_results=original_code_baseline, original_helper_code=original_helper_code, file_path_to_helper_classes=file_path_to_helper_classes, + eval_ctx=eval_ctx, code_context=code_context, candidate=candidate, exp_type=exp_type, @@ -2045,6 +2047,7 @@ def repair_if_possible( self, candidate: OptimizedCandidate, diffs: list[TestDiff], + eval_ctx: CandidateEvaluationContext, code_context: CodeOptimizationContext, test_results_count: int, exp_type: str, @@ -2052,6 +2055,12 @@ def repair_if_possible( if self.repair_counter >= MAX_REPAIRS_PER_TRACE: logger.debug(f"Repair counter reached {MAX_REPAIRS_PER_TRACE}, skipping repair") return + + successful_candidates_count = sum(1 for is_correct in eval_ctx.is_correct.values() if is_correct) + if successful_candidates_count >= MIN_CORRECT_CANDIDATES: + logger.debug(f"{successful_candidates_count} of the candidates were correct, no need to repair") + return + if candidate.source not in (OptimizedCandidateSource.OPTIMIZE, OptimizedCandidateSource.OPTIMIZE_LP): # only repair the first pass of the candidates for now logger.debug(f"Candidate is a result of {candidate.source.value}, skipping repair") @@ -2089,6 +2098,7 @@ def run_optimized_candidate( baseline_results: OriginalCodeBaseline, original_helper_code: dict[Path, str], file_path_to_helper_classes: dict[Path, set[str]], + eval_ctx: CandidateEvaluationContext, code_context: CodeOptimizationContext, candidate: OptimizedCandidate, exp_type: str, @@ -2144,7 +2154,9 @@ def run_optimized_candidate( logger.info("h3|Test results matched ✅") console.rule() else: - self.repair_if_possible(candidate, diffs, code_context, len(candidate_behavior_results), exp_type) + self.repair_if_possible( + candidate, diffs, eval_ctx, code_context, len(candidate_behavior_results), exp_type + ) return self.get_results_not_matched_error() logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")