Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 122 additions & 16 deletions codeflash/languages/java/remove_asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,9 +988,12 @@ def _infer_type_from_assertion_args(self, original_text: str, method: str) -> st
# If the first arg is a string literal, check if there are 3+ args — if so, the real expected
# value is the second argument, not the message string.
if expected.startswith('"') and method in ("assertEquals", "assertNotEquals"):
all_args = self._split_top_level_args(args_str)
if len(all_args) >= 3:
expected = all_args[1].strip()
# Use a lightweight scan that stops once we confirm there are >=3 top-level args and
# extracts the second argument, avoiding the full split.
second = self._second_arg_if_message(args_str)
if second is not None:
expected = second.strip()


return self._type_from_literal(expected)

Expand Down Expand Up @@ -1195,41 +1198,144 @@ def _extract_first_arg(self, args_str: str) -> str | None:
if i >= n:
return None


start = i
depth = 0
in_string = False
string_char = ""
cur: list[str] = []

while i < n:
ch = args_str[i]

if in_string:
cur.append(ch)
if ch == "\\" and i + 1 < n:
i += 1
cur.append(args_str[i])
elif ch == string_char:
# Skip escaped character
i += 2
continue
if ch == string_char:
in_string = False
elif ch in ('"', "'"):
i += 1
continue

if ch in ('"', "'"):
in_string = True
string_char = ch
cur.append(ch)
i += 1
elif ch in ("(", "<", "[", "{"):
depth += 1
cur.append(ch)
i += 1
elif ch in (")", ">", "]", "}"):
depth -= 1
cur.append(ch)
i += 1
elif ch == "," and depth == 0:
# end just before comma
end = i
break
else:
cur.append(ch)
i += 1
i += 1
else:
# reached end without a top-level comma
end = i

# Trim trailing whitespace from the extracted argument
if not cur:
if start >= end:
return None
return args_str[start:end].rstrip()

def _second_arg_if_message(self, args_str: str) -> str | None:
    """Return the second top-level argument when there are at least three.

    Intended for assertion calls whose first argument is a message string
    (e.g. ``assertEquals("msg", expected, actual)``): the caller has already
    seen that the first argument looks like a string literal and wants the
    real expected value instead.

    The scan is a single left-to-right pass that stops as soon as the second
    top-level comma is found, so the full argument list is never split.

    Args:
        args_str: The raw text between the parentheses of the call.

    Returns:
        The text between the first and second top-level commas, exactly as it
        appears in ``args_str`` (surrounding whitespace is NOT stripped), or
        ``None`` when fewer than three top-level arguments are present.

    Notes:
        * Commas inside string/char literals or inside ``()``, ``[]``, ``{}``
          and ``<>`` are not treated as argument separators.
        * ``<`` and ``>`` are counted as brackets so that generic type
          arguments such as ``Map<String, Integer>`` stay in one piece; as a
          consequence, a bare comparison or ``->`` in an argument unbalances
          the depth counter and later commas are no longer seen as top-level.
    """
    n = len(args_str)
    depth = 0
    quote = ""  # the opening quote char while inside a literal, "" otherwise
    first_comma = -1
    i = 0

    while i < n:
        ch = args_str[i]

        if quote:
            if ch == "\\" and i + 1 < n:
                # Skip the backslash and the character it escapes so an
                # escaped quote does not terminate the literal.
                i += 2
                continue
            if ch == quote:
                quote = ""
        elif ch in ('"', "'"):
            quote = ch
        elif ch in ("(", "<", "[", "{"):
            depth += 1
        elif ch in (")", ">", "]", "}"):
            depth -= 1
        elif ch == "," and depth == 0:
            if first_comma == -1:
                first_comma = i
            else:
                # Second top-level comma: a third argument follows, so the
                # span between the two commas is the second argument.
                return args_str[first_comma + 1 : i]

        i += 1

    # Zero or one top-level comma: fewer than three arguments.
    return None


def transform_java_assertions(
Expand Down
Loading