diff --git a/codeflash/languages/java/remove_asserts.py b/codeflash/languages/java/remove_asserts.py index a5a5986c9..15379fa23 100644 --- a/codeflash/languages/java/remove_asserts.py +++ b/codeflash/languages/java/remove_asserts.py @@ -988,9 +988,12 @@ def _infer_type_from_assertion_args(self, original_text: str, method: str) -> st # If the first arg is a string literal, check if there are 3+ args — if so, the real expected # value is the second argument, not the message string. if expected.startswith('"') and method in ("assertEquals", "assertNotEquals"): - all_args = self._split_top_level_args(args_str) - if len(all_args) >= 3: - expected = all_args[1].strip() + # Use a lightweight scan that stops once we confirm there are >=3 top-level args and + # extracts the second argument, avoiding the full split. + second = self._second_arg_if_message(args_str) + if second is not None: + expected = second.strip() + return self._type_from_literal(expected) @@ -1195,41 +1198,144 @@ def _extract_first_arg(self, args_str: str) -> str | None: if i >= n: return None + + start = i depth = 0 in_string = False string_char = "" - cur: list[str] = [] while i < n: ch = args_str[i] if in_string: - cur.append(ch) if ch == "\\" and i + 1 < n: - i += 1 - cur.append(args_str[i]) - elif ch == string_char: + # Skip escaped character + i += 2 + continue + if ch == string_char: in_string = False - elif ch in ('"', "'"): + i += 1 + continue + + if ch in ('"', "'"): in_string = True string_char = ch - cur.append(ch) + i += 1 elif ch in ("(", "<", "[", "{"): depth += 1 - cur.append(ch) + i += 1 elif ch in (")", ">", "]", "}"): depth -= 1 - cur.append(ch) + i += 1 elif ch == "," and depth == 0: + # end just before comma + end = i break else: - cur.append(ch) - i += 1 + i += 1 + else: + # reached end without a top-level comma + end = i # Trim trailing whitespace from the extracted argument - if not cur: + if start >= end: + return None + return args_str[start:end].rstrip() + + def _second_arg_if_message(self, args_str: str) -> str | None: + """If the first top-level arg is a string message and there are >=3 top-level args, + return the second top-level arg. Otherwise return None. + + This performs a short-circuit parse that counts top-level commas and extracts + only as much as needed (stops after finding the second top-level comma). + """ + n = len(args_str) + i = 0 + + # skip leading whitespace + while i < n and args_str[i].isspace(): + i += 1 + if i >= n: + return None + + depth = 0 + in_string = False + string_char = "" + + # Find first top-level comma (end of first arg) + first_comma = -1 + while i < n: + ch = args_str[i] + if in_string: + if ch == "\\" and i + 1 < n: + i += 2 + continue + if ch == string_char: + in_string = False + i += 1 + continue + + if ch in ('"', "'"): + in_string = True + string_char = ch + i += 1 + elif ch in ("(", "<", "[", "{"): + depth += 1 + i += 1 + elif ch in (")", ">", "]", "}"): + depth -= 1 + i += 1 + elif ch == "," and depth == 0: + first_comma = i + i += 1 + break + else: + i += 1 + + if first_comma == -1: + # only one argument return None - return "".join(cur).rstrip() + + # Now find second top-level comma (i currently at char after first comma) + depth = 0 + in_string = False + string_char = "" + second_comma = -1 + while i < n: + ch = args_str[i] + if in_string: + if ch == "\\" and i + 1 < n: + i += 2 + continue + if ch == string_char: + in_string = False + i += 1 + continue + + if ch in ('"', "'"): + in_string = True + string_char = ch + i += 1 + elif ch in ("(", "<", "[", "{"): + depth += 1 + i += 1 + elif ch in (")", ">", "]", "}"): + depth -= 1 + i += 1 + elif ch == "," and depth == 0: + second_comma = i + break + else: + i += 1 + + if second_comma == -1: + # fewer than 3 args + return None + + # Extract second arg between first_comma and second_comma + start = first_comma + 1 + end = second_comma + return args_str[start:end] def transform_java_assertions(