diff --git a/.tekton/on-pull-request.yaml b/.tekton/on-pull-request.yaml
index 6315138b..877b2d20 100644
--- a/.tekton/on-pull-request.yaml
+++ b/.tekton/on-pull-request.yaml
@@ -160,6 +160,8 @@ spec:
env:
- name: TARGET_BRANCH_NAME
value: "{{target_branch}}"
+ - name: JAVA_MAVEN_DEFAULT_SETTINGS_FILE_PATH
+ value: $(workspaces.source.path)/kustomize/base/settings.xml
image: registry.access.redhat.com/ubi9/python-312:9.6
workingDir: $(workspaces.source.path)
script: |
@@ -184,8 +186,8 @@ spec:
# Install Java
JAVA_ARCH="x64"
- JDK_URL="https://github.com/adoptium/temurin22-binaries/releases/download/jdk-22.0.2%2B9/OpenJDK22U-jdk_${JAVA_ARCH}_linux_hotspot_22.0.2_9.tar.gz"
- JDK_DIR="jdk-22.0.2+9"
+ JDK_URL="https://github.com/adoptium/temurin17-binaries/releases/download/jdk-17.0.18%2B8/OpenJDK17U-jdk_${JAVA_ARCH}_linux_hotspot_17.0.18_8.tar.gz"
+ JDK_DIR="jdk-17.0.18+8"
echo ">> Downloading $JDK_URL"
mkdir -p /tekton/home/jdk
diff --git a/Dockerfile b/Dockerfile
index 96133ca0..f4b7d137 100755
--- a/Dockerfile
+++ b/Dockerfile
@@ -66,8 +66,8 @@ ENV PATH="/opt/nodejs/bin:${PATH}"
RUN node --version && npm --version
-# --- Temurin JDK 22 (amd64/x86_64) ---
+# --- Temurin JDK 17 (amd64/x86_64) ---
-ARG JDK_URL="https://github.com/adoptium/temurin22-binaries/releases/download/jdk-22.0.2%2B9/OpenJDK22U-jdk_x64_linux_hotspot_22.0.2_9.tar.gz"
-ARG JDK_DIR="jdk-22.0.2+9"
+ARG JDK_URL="https://github.com/adoptium/temurin17-binaries/releases/download/jdk-17.0.18%2B8/OpenJDK17U-jdk_x64_linux_hotspot_17.0.18_8.tar.gz"
+ARG JDK_DIR="jdk-17.0.18+8"
RUN mkdir -p /opt/jdk \
&& curl -fsSL -o /tmp/jdk.tgz "${JDK_URL}" \
&& tar -C /opt/jdk -xzf /tmp/jdk.tgz \
diff --git a/kustomize/base/exploit-iq-config.yml b/kustomize/base/exploit-iq-config.yml
index ad70c5f3..b9eab81d 100644
--- a/kustomize/base/exploit-iq-config.yml
+++ b/kustomize/base/exploit-iq-config.yml
@@ -67,6 +67,8 @@ functions:
enable_functions_usage_search: true
Function Locator:
_type: package_and_function_locator
+ Function Library Version Finder:
+ _type: calling_function_library_version_finder
Code Semantic Search:
_type: local_vdb_retriever
embedder_name: nim_embedder
@@ -98,6 +100,7 @@ functions:
- Call Chain Analyzer
- Function Caller Finder
- Function Locator
+ - Function Library Version Finder
max_concurrency: null
max_iterations: 10
prompt_examples: false
diff --git a/kustomize/base/exploit_iq_service.yaml b/kustomize/base/exploit_iq_service.yaml
index f64ca472..8dcf43e1 100644
--- a/kustomize/base/exploit_iq_service.yaml
+++ b/kustomize/base/exploit_iq_service.yaml
@@ -125,9 +125,13 @@ spec:
value: /exploit-iq-package-cache/go/pkg/mod
- name: ENABLE_MLOPS
value: "true"
+ - name: JAVA_MAVEN_DEFAULT_SETTINGS_FILE_PATH
+ value: /maven-config/settings.xml
volumeMounts:
- name: config
mountPath: /configs
+ - name: maven-settings-config
+ mountPath: /maven-config
- name: cache
mountPath: /exploit-iq-data
- name: package-cache
@@ -139,6 +143,9 @@ spec:
- name: config
configMap:
name: exploit-iq-config
+ - name: maven-settings-config
+ configMap:
+ name: exploit-iq-maven-settings-config
- name: cache
persistentVolumeClaim:
claimName: exploit-iq-data
diff --git a/kustomize/base/kustomization.yaml b/kustomize/base/kustomization.yaml
index d6dbd4a2..30a1cd1e 100644
--- a/kustomize/base/kustomization.yaml
+++ b/kustomize/base/kustomization.yaml
@@ -63,6 +63,9 @@ configMapGenerator:
files:
- excludes.json
- includes.json
+ - name: exploit-iq-maven-settings-config
+ files:
+ - settings.xml
patches:
- path: ips-patch.json
diff --git a/kustomize/base/settings.xml b/kustomize/base/settings.xml
new file mode 100644
index 00000000..0f9299f3
--- /dev/null
+++ b/kustomize/base/settings.xml
@@ -0,0 +1,62 @@
+
+
+
+
+ red-hat
+
+
+ red-hat-ga
+ https://maven.repository.redhat.com/ga
+
+
+ jboss-public-repository-group
+ JBoss Public Maven Repository Group
+ https://repository.jboss.org/nexus/content/groups/public/
+ default
+
+ true
+ never
+
+
+ true
+ never
+
+
+
+
+
+ red-hat-ga
+ https://maven.repository.redhat.com/ga
+
+ true
+
+
+ false
+
+
+
+
+
+
+ red-hat
+
+
+
+
+
+ maven-default-http-blocker
+ Disabled default HTTP blocker
+
+
+ __no_such_repo_id__
+
+
+ https://repo1.maven.org/maven2
+
+
+
+
\ No newline at end of file
diff --git a/kustomize/config-http-openai-local.yml b/kustomize/config-http-openai-local.yml
index 45dabca2..fa5b8976 100644
--- a/kustomize/config-http-openai-local.yml
+++ b/kustomize/config-http-openai-local.yml
@@ -83,6 +83,8 @@ functions:
max_retries: 5
Container Analysis Data:
_type: container_image_analysis_data
+ Function Library Version Finder:
+ _type: calling_function_library_version_finder
cve_agent_executor:
_type: cve_agent_executor
llm_name: cve_agent_executor_llm
@@ -94,6 +96,7 @@ functions:
- Call Chain Analyzer
- Function Caller Finder
- Function Locator
+ - Function Library Version Finder
max_concurrency: null
max_iterations: 10
prompt_examples: false
diff --git a/src/exploit_iq_commons/utils/dep_tree.py b/src/exploit_iq_commons/utils/dep_tree.py
index b84a2051..5de01a52 100644
--- a/src/exploit_iq_commons/utils/dep_tree.py
+++ b/src/exploit_iq_commons/utils/dep_tree.py
@@ -842,40 +842,115 @@ class JavaDependencyTreeBuilder(DependencyTreeBuilder):
def __init__(self, query: str):
self._query = query
+ def __check_file_exists(self, dir_path: str | Path, filename: str) -> bool:
+ """
+ Return True iff `filename` exists as a regular file directly under `dir_path`.
+ """
+ p = Path(dir_path) / filename
+ return p.is_file()
+
def install_dependencies(self, manifest_path: Path):
+ mvn_command = "mvn"
+ settings_path = os.getenv("JAVA_MAVEN_DEFAULT_SETTINGS_FILE_PATH", "../../../../kustomize/base/settings.xml")
source_path = "dependencies-sources"
- subprocess.run(["mvn", "dependency:copy-dependencies", "-Dclassifier=sources",
- f"-DoutputDirectory={source_path}"], cwd=manifest_path)
+
+ if self.__check_file_exists(manifest_path, "mvnw"):
+ mvn_command = "./mvnw"
+
+ process_object = subprocess.run([mvn_command, "-s", settings_path, "dependency:copy-dependencies", "-Dclassifier=sources",
+ "-DincludeScope=runtime", f"-DoutputDirectory={manifest_path.resolve()}/{source_path}"], cwd=manifest_path)
+
+ if process_object.returncode != 0:
+ process_object = subprocess.run([mvn_command, "clean", "install",
+ "-DskipTests", "-s", settings_path], cwd=manifest_path)
+ if process_object.returncode != 0:
+ formatted_error_msg = (
+ f"Failed to build project for "
+ f"manifest at {manifest_path}, error details => "
+ f"mvn exit code {process_object.returncode}"
+ )
+ raise Exception(formatted_error_msg)
+
+ process_object = subprocess.run([mvn_command, "-s", settings_path,
+ "dependency:copy-dependencies", "-Dclassifier=sources",
+ "-DincludeScope=runtime", f"-DoutputDirectory={manifest_path.resolve()}/{source_path}"], cwd=manifest_path)
+
+ if process_object.returncode != 0:
+ formatted_error_msg = (
+ f"Failed to install dependencies"
+ f"manifest at {manifest_path}, error details => "
+ f"mvn exit code {process_object.returncode}"
+ )
+ raise Exception(formatted_error_msg)
full_source_path = manifest_path / source_path
for jar in full_source_path.glob("*-sources.jar"):
- dest = full_source_path / jar.stem # folder named after jar
+ if jar.stat().st_size > 0:
+ dest = full_source_path / jar.stem # folder named after jar
- if not dest.exists():
- dest.mkdir(exist_ok=True)
- with zipfile.ZipFile(jar, "r") as zf:
- zf.extractall(dest)
+ if not dest.exists():
+ dest.mkdir(exist_ok=True)
+ with zipfile.ZipFile(jar, "r") as zf:
+ zf.extractall(dest)
def build_tree(self, manifest_path: Path) -> dict[str, list[str]]:
+ mvn_command = "mvn"
+ settings_path = os.getenv("JAVA_MAVEN_DEFAULT_SETTINGS_FILE_PATH", "../../../../kustomize/base/settings.xml")
dependency_file = manifest_path / "dependency_tree.txt"
- package_name = self._query.split(',')[0]
+ package_name = self._query.split(",")[0]
+
+ if self.__check_file_exists(manifest_path, "mvnw"):
+ mvn_command = "./mvnw"
if is_maven_gav(package_name):
- with open(dependency_file, "w") as f:
- subprocess.run(["mvn", "dependency:tree",
- f"-Dincludes={add_missing_jar_string(package_name)}",
- "-Dverbose"], cwd=manifest_path, stdout=f, check=True)
+ subprocess.run(
+ [
+ mvn_command,
+ "com.github.ferstl:depgraph-maven-plugin:4.0.3:aggregate",
+ "-s", settings_path,
+ "-DgraphFormat=text",
+ "-DshowGroupIds",
+ "-DshowVersions",
+ "-DshowTypes",
+ "-DoutputDirectory=.",
+ "-DoutputFileName=dependency_tree.txt",
+ "-DclasspathScope=runtime",
+ f"-DtargetIncludes={add_missing_jar_string(package_name)}",
+ ],
+ cwd=manifest_path,
+ check=True,
+ )
else:
- with open(dependency_file, "w") as f:
- subprocess.run(["mvn", "dependency:tree",
- "-Dverbose"], cwd=manifest_path, stdout=f, check=True)
+ subprocess.run(
+ [
+ mvn_command,
+ "com.github.ferstl:depgraph-maven-plugin:4.0.3:aggregate",
+ "-s", settings_path,
+ "-DgraphFormat=text",
+ "-DshowGroupIds",
+ "-DshowVersions",
+ "-DshowTypes",
+ "-DoutputDirectory=.",
+ "-DclasspathScope=runtime",
+ "-DoutputFileName=dependency_tree.txt",
+ ],
+ cwd=manifest_path,
+ check=True,
+ )
with dependency_file.open("r", encoding="utf-8") as f:
lines = f.readlines()
parent, graph = self.__build_upside_down_dependency_graph(lines)
- # Mark the top level for
+ # Mark *all* roots, not just the first one.
+ # A "root" is any node with no parents in the computed parent-chain list.
+ # This preserves old single-root behavior and fixes multi-root / multi-parent trees.
+ roots = [node for node, parents in graph.items() if not parents]
+ for r in roots:
+ graph[r] = [ROOT_LEVEL_SENTINEL]
+
+ # Backward-compatible: keep the old behavior too (harmless if already set above)
+ if parent: graph[parent] = [ROOT_LEVEL_SENTINEL]
return graph
@@ -885,8 +960,7 @@ def __build_upside_down_dependency_graph(
) -> Tuple[str, Dict[str, List[str]]]:
root: str = ""
stack: List[str] = []
- # coord -> set of direct parents (possibly multiple)
- graph_sets: Dict[str, set] = {}
+ graph_sets: Dict[str, set[str]] = {} # coord -> set of direct parents
for line in dependency_lines:
depth, coord = self.__parse_dependency_line(line)
@@ -894,7 +968,8 @@ def __build_upside_down_dependency_graph(
continue
if depth == 0:
- # start (or restart) a root line
+ # depgraph aggregate can emit multiple top-level roots. Keep the first as "root"
+ # for backward compatibility, but still record others as separate roots in graph_sets.
if not root:
root = coord
stack = [coord]
@@ -907,91 +982,128 @@ def __build_upside_down_dependency_graph(
parent = stack[-1] if stack else None
if parent is not None:
- graph_sets.setdefault(coord, set()).add(parent)
+ graph_sets.setdefault(coord, set()).add(parent) # supports multiple direct parents
graph_sets.setdefault(parent, set())
else:
graph_sets.setdefault(coord, set())
stack.append(coord)
- # ---------- second phase: all parents (direct + transitive) without duplicates ----------
-
def build_parent_chain(node: str) -> List[str]:
"""
- For a given coord, return a flat list of *all* parents reachable
- via any path up to the root, with no duplicates.
-
- Order: breadth-first from nearest parents outward.
+ Return a flat list of all parents reachable via any path, no duplicates.
+ Deterministic BFS order: nearest parents outward.
"""
result: List[str] = []
seen: set[str] = set()
- q = deque(graph_sets.get(node, ()))
+ q = deque(sorted(graph_sets.get(node, ())))
while q:
- parent = q.popleft()
- if parent in seen:
+ p = q.popleft()
+ if p in seen:
continue
- seen.add(parent)
- result.append(parent)
+ seen.add(p)
+ result.append(p)
- # enqueue this parent's parents
- for gp in graph_sets.get(parent, ()):
+ for gp in sorted(graph_sets.get(p, ())):
if gp not in seen:
q.append(gp)
return result
- graph: Dict[str, List[str]] = {
- coord: build_parent_chain(coord) for coord in graph_sets.keys()
- }
-
+ graph: Dict[str, List[str]] = {coord: build_parent_chain(coord) for coord in graph_sets.keys()}
return root, graph
def __parse_dependency_line(self, line: str) -> Tuple[Optional[int], Optional[str]]:
- if not line.startswith("[INFO]"):
+ """
+ Parse one dependency line from depgraph's graphFormat=text output.
+
+ Expected depgraph token shape (after indentation/branch prefix):
+ groupId:artifactId:version:type:scope
+ Example:
+ org.apache.activemq:artemis-openwire-protocol:2.28.0:bundle:compile
+
+ We return (depth, "groupId:artifactId:version") and ignore type/scope/optional marker.
+
+ Also tolerates Maven log prefixes like "[INFO] " if they appear.
+ """
+ raw = (line or "").rstrip("\n")
+ if not raw.strip():
+ return None, None
+
+ # If Maven stdout and depgraph output got mixed, you may see mid-line "[INFO]" injection.
+ # Those lines are not safely recoverable as dependency tokens.
+ if "[INFO]" in raw and not raw.lstrip().startswith("[INFO]"):
return None, None
- # Keep indentation blocks; Maven prints exactly one space after "[INFO]"
- s = line[6:]
- if s.startswith(" "):
- s = s[1:]
+ s = raw.lstrip()
- # Skip non-tree lines early
- if (not s
- or s.startswith(("---", "BUILD", "Scanning", "Finished", "Total time"))
- or ":" not in s):
+ # Strip Maven log prefix if present
+ if s.startswith("[INFO]"):
+ s = s[6:].lstrip()
+
+ # Skip headers and build noise
+ if (
+ not s
+ or "Dependency graph:" in s
+ or s.startswith(("---", "BUILD", "Reactor Summary", "Total time", "Finished at", "Scanning"))
+ or s.startswith("[") # other log levels like [WARNING], [ERROR], etc.
+ or ":" not in s
+ ):
return None, None
- # indent blocks ("| " or " ") + optional "+- " or "\- " + rest
- m = re.match(r'^(?P(?:\| | )*)(?P[+\\]-\s)?(?P.+)$', s)
+ # depgraph indentation blocks ("| " or " ") + optional "+- " or "\- " + rest
+ m = re.match(r"^(?P(?:\| | )*)(?P[+\\]-\s)?(?P.+)$", s)
if not m:
return None, None
- # Each indent block is 3 chars; add 1 if a branch token is present
- depth = (len(m.group('indent')) // 3) + (1 if m.group('branch') else 0)
- rest = m.group('rest').strip()
+ depth = (len(m.group("indent")) // 3) + (1 if m.group("branch") else 0)
+ rest = m.group("rest").strip()
# First token up to whitespace or ')', optionally starting with '('
- m2 = re.match(r'^\(?([^\s\)]+)\)?', rest)
+ m2 = re.match(r"^\(?([^\s\)]+)\)?", rest)
if not m2:
return None, None
- token = m2.group(1) # e.g., io.foo:bar:jar:1.2.3:compile
- parts = token.split(':')
+ token = m2.group(1) # e.g. com.google.guava:guava:32.0.1-jre:jar:compile
+ parts = token.split(":")
- # Drop trailing Maven scope if present
- scopes = {'compile', 'runtime', 'test', 'provided', 'system', 'import'}
+ scopes = {"compile", "runtime", "test", "provided", "system", "import"}
if parts and parts[-1] in scopes:
parts = parts[:-1]
- # Expect group:artifact:type:(classifier:)version — return without the type
- if len(parts) >= 4:
- group, artifact = parts[0], parts[1]
- version = parts[-1]
- coord = f"{group}:{artifact}:{version}"
- return depth, coord
+ if len(parts) < 3:
+ return None, None
+
+ group, artifact = parts[0], parts[1]
+
+ # depgraph text format puts version in position 2:
+ # group:artifact:version:type (scope already removed)
+ # We detect that by checking whether the last part is a packaging/type marker.
+ packaging = {"jar", "war", "pom", "bundle", "maven-plugin", "ear", "ejb", "rar", "zip", "test-jar"}
+
+ def looks_like_version(v: str) -> bool:
+ return any(ch.isdigit() for ch in v)
+
+ version: Optional[str] = None
+
+ # depgraph: group:artifact:version:type
+ if len(parts) >= 4 and parts[-1] in packaging and looks_like_version(parts[2]):
+ version = parts[2]
+ # depgraph (rare): group:artifact:version:type:classifier
+ elif len(parts) >= 5 and parts[-2] in packaging and looks_like_version(parts[2]):
+ version = parts[2]
+ else:
+ # Fallback for other Maven-like formats where version is last
+ if looks_like_version(parts[-1]):
+ version = parts[-1]
+ elif looks_like_version(parts[2]):
+ version = parts[2]
+ else:
+ return None, None
- return None, None
+ coord = f"{group}:{artifact}:{version}"
+ return depth, coord
class PythonDependencyTreeBuilder(DependencyTreeBuilder):
diff --git a/src/exploit_iq_commons/utils/functions_parsers/java_functions_parsers.py b/src/exploit_iq_commons/utils/functions_parsers/java_functions_parsers.py
index ac6b50f1..defc9498 100644
--- a/src/exploit_iq_commons/utils/functions_parsers/java_functions_parsers.py
+++ b/src/exploit_iq_commons/utils/functions_parsers/java_functions_parsers.py
@@ -1,15 +1,16 @@
import hashlib
import os
import re
+from functools import lru_cache
from typing import Dict, Tuple, List, Optional
from tqdm import tqdm
from langchain_core.documents import Document
from exploit_iq_commons.utils.functions_parsers.lang_functions_parsers import LanguageFunctionsParser
-from exploit_iq_commons.utils.java_utils import extract_jar_name, JAVA_METHOD_PRIM_TYPES, collect_fields_from_types, get_type_name, \
+from exploit_iq_commons.utils.java_utils import extract_jar_name, JAVA_METHOD_PRIM_TYPES, collect_fields_from_types, \
is_java_type, is_java_method, extract_method_name_with_params, find_function, get_target_class_names, \
- strip_java_generics, JAVA_ANNOTATION_SYMBOL
+ strip_java_generics, JAVA_ANNOTATION_SYMBOL, extract_fqcn
from exploit_iq_commons.logging.loggers_factory import LoggingFactory
logger = LoggingFactory.get_agent_logger(f"morpheus.{__name__}")
@@ -25,61 +26,174 @@
}
TRAILING_ARR_RE = re.compile(r'(?:\s*\[\s*\]\s*)+$')
+_ARRAY_SUFFIX_RE = re.compile(r"\s*(\[\])+$")
+_WS_RE = re.compile(r"\s+")
+
+# Supports:
+# /* inner-type: Inner */
+# /* inner-type: com.foo.Outer.Inner */
+# // inner-type: Inner
+_INNER_TYPE_RE = re.compile(
+ r"""
+(?:
+/\*\s*inner-type\s*:\s*(?P.*?)\s*\*/ # block comment marker
+|
+//\s*inner-type\s*:\s*(?P[^\r\n]*) # line comment marker
+)
+""",
+ re.IGNORECASE | re.DOTALL | re.VERBOSE,
+ )
class JavaLanguageFunctionsParser(LanguageFunctionsParser):
def parse_all_type_struct_class_to_fields(
self,
types: list[Document],
- inheritance_map: dict[Tuple[str, str], List[Tuple[str, str]]]
- ) -> dict[tuple, list[tuple[str, str]]]:
+ inheritance_map: dict[Tuple[str, str], List[Tuple[str, str]]],
+ ) -> dict[Tuple[str, str], list[tuple[str, str]]]:
"""
- Returns (TypeSimpleName, doc.metadata['source']) -> [(memberName, memberType), ...]
+ Returns (TypeFqcn, doc.metadata['source']) -> [(memberName, memberType), ...]
+
- Enums are ignored entirely.
- Interfaces: includes only fields (constants).
- Classes/records: includes fields.
- Also includes inherited fields (from superclasses and implemented interfaces) using BFS.
- Uses only the existing `inheritance_map`.
- - Critically, it classifies *upstream* (parents/interfaces) via relative position indices to avoid pulling in descendants.
+ - Classifies *upstream* (parents/interfaces) via relative position indices to avoid pulling in descendants.
+
+ NOTE: `inheritance_map` keys are now FQCNs and may include inner FQCNs.
"""
from collections import deque
from typing import Dict, Tuple, Any, List
- results: Dict[Tuple[str, Any], List[Tuple[str, str]]] = {}
+ results: Dict[Tuple[str, str], List[Tuple[str, str]]] = {}
+
+ # ---------------------------
+ # CHANGED: build a fast (source, parsed-name) -> fqcn resolver from inheritance_map keys
+ # ---------------------------
+ def _norm_fqcn(fqcn: str) -> str:
+ # normalize inner separator so suffix matching works regardless of '$' vs '.'
+ return fqcn.replace("$", ".")
+
+ def _strip_package_prefix(fqcn_dot: str) -> str:
+ """
+ Convert "org.foo.Outer.Inner" -> "Outer.Inner" by dropping package tokens.
+ Uses the heuristic: first segment starting with uppercase begins the type chain.
+ """
+ parts = fqcn_dot.split(".")
+ for idx, tok in enumerate(parts):
+ if tok and tok[0].isupper():
+ return ".".join(parts[idx:])
+ return fqcn_dot
+
+ def _put_unique(m: dict[str, str | None], k: str, v: str) -> None:
+ # store unique, mark ambiguous with None
+ prev = m.get(k)
+ if prev is None:
+ # already ambiguous or unset (if unset None isn't distinguishable), handle below:
+ if k in m:
+ return
+ m[k] = v
+ else:
+ if prev != v:
+ m[k] = None
+
+ # Index by file source, because inheritance_map keys include (fqcn, source)
+ suffix_to_fqcn_by_source: dict[str, dict[str, str | None]] = {}
+ simple_to_fqcn_by_source: dict[str, dict[str, str | None]] = {}
+
+ for (fqcn, src_path) in inheritance_map.keys():
+ fqcn_dot = _norm_fqcn(fqcn)
+ no_pkg = _strip_package_prefix(fqcn_dot) # e.g. Outer.Inner
+ simple = no_pkg.rsplit(".", 1)[-1] # e.g. Inner
+
+ suf_map = suffix_to_fqcn_by_source.setdefault(src_path, {})
+ sim_map = simple_to_fqcn_by_source.setdefault(src_path, {})
+
+ _put_unique(suf_map, no_pkg, fqcn) # "Outer.Inner" -> fqcn (may be with '$')
+ _put_unique(sim_map, simple, fqcn) # "Inner" -> fqcn (if unique)
+
+ # also allow direct normalized fqcn dot form as a key for rare parsers returning fqcn-ish names
+ _put_unique(suf_map, fqcn_dot, fqcn)
+
+ def _resolve_fqcn_for_parsed_name(src_path: str, parsed_name: str) -> str | None:
+ """
+ Map `parsed_name` produced by collect_fields_from_types() to the FQCN key used by inheritance_map.
+
+ Fast-paths:
+ - If parsed_name includes dots (likely "Outer.Inner"), resolve via suffix map.
+ - Else resolve via unique simple-name map.
+ """
+ if not parsed_name:
+ return None
+ name = parsed_name.strip()
+ if not name:
+ return None
+ name_dot = name.replace("$", ".") # be tolerant
+
+ suf_map = suffix_to_fqcn_by_source.get(src_path)
+ sim_map = simple_to_fqcn_by_source.get(src_path)
+
+ if "." in name_dot:
+ if suf_map:
+ fq = suf_map.get(name_dot)
+ return fq if fq else None
+ return None
+ if sim_map:
+ fq = sim_map.get(name_dot)
+ return fq if fq else None
+
+ return None
+
+ # ---------------------------
# 1) Collect own fields per document
+ # CHANGED: per_key_fields now keyed by (fqcn, source) not (simple, source)
+ # ---------------------------
per_key_fields: Dict[Tuple[str, str], List[Tuple[str, str]]] = {}
for doc in tqdm(types, total=len(types), desc="Parsing class/interface/enum/record documents for members..."):
src = doc.page_content
- source = doc.metadata['source']
+ source = doc.metadata["source"]
if not src:
continue
per_file = collect_fields_from_types(src, include_nested=True, include_anonymous=True)
- for class_name, fields in per_file.items():
- per_key_fields[(class_name, source)] = fields
+ # Map each parsed type name to the inheritance_map fqcn key for this source file
+ for parsed_type_name, fields in per_file.items():
+ fqcn = _resolve_fqcn_for_parsed_name(source, parsed_type_name)
+ if not fqcn:
+ # If we cannot map reliably to an inheritance_map key, skip (prevents wrong merges).
+ continue
+ per_key_fields[(fqcn, source)] = fields
+
+ # ---------------------------
# 2) Precompute: row lists and O(1) index lookups for each key
+ # ---------------------------
rows: Dict[Tuple[str, str], List[Tuple[str, str]]] = inheritance_map
- row_index: Dict[Tuple[str, str], Dict[Tuple[str, str], int]] = {}
- for k, lst in rows.items():
- # stable O(1) index lookup
- row_index[k] = {v: i for i, v in enumerate(lst)}
-
- # 3) Upstream classifier (parents + interfaces) using relative index rule:
- # candidate is upstream of type_key <=> candidate ∈ rows[type_key],
- # type_key ∈ rows[candidate], and
- # row_index[type_key][candidate] < row_index[candidate][type_key]
+ row_index: Dict[Tuple[str, str], Dict[Tuple[str, str], int]] = {
+ type_key: {neighbor: i for i, neighbor in enumerate(neighbors)}
+ for type_key, neighbors in rows.items()
+ }
+
+ # ---------------------------
+ # 3) Upstream classifier with caching (performance)
+ # ---------------------------
+ upstream_cache: Dict[Tuple[str, str], List[Tuple[str, str]]] = {}
+
def upstream_neighbors(type_key: Tuple[str, str]) -> List[Tuple[str, str]]:
"""
Return the direct upstream types (superclasses and implemented interfaces) for `type_key`
using the relative-index rule over `rows` and `row_index`.
"""
+ cached = upstream_cache.get(type_key)
+ if cached is not None:
+ return cached
+
type_row = rows.get(type_key)
if not type_row:
+ upstream_cache[type_key] = []
return []
- # Fast lookup: neighbor -> index in type_key's row
type_index_map = row_index.get(type_key, {})
-
upstreams: List[Tuple[str, str]] = []
seen_candidates: set[Tuple[str, str]] = set()
@@ -87,10 +201,8 @@ def upstream_neighbors(type_key: Tuple[str, str]) -> List[Tuple[str, str]]:
if candidate == type_key or candidate in seen_candidates:
continue
- # Index of candidate within type_key's row
idx_candidate_in_type = type_index_map.get(candidate, 1 << 30)
- # Index of type_key within candidate's row
candidate_index_map = row_index.get(candidate)
if not candidate_index_map:
continue
@@ -98,69 +210,76 @@ def upstream_neighbors(type_key: Tuple[str, str]) -> List[Tuple[str, str]]:
if idx_type_in_candidate is None:
continue
- # Ancestor/interface test via relative positions
if idx_candidate_in_type < idx_type_in_candidate:
upstreams.append(candidate)
seen_candidates.add(candidate)
+ upstream_cache[type_key] = upstreams
return upstreams
+ # ---------------------------
# 4) Merge own + inherited fields per type using BFS over upstream neighbors.
+ # ---------------------------
for type_key, own_fields in per_key_fields.items():
merged: List[Tuple[str, str]] = []
have_names: set[str] = set() # child-over-parent shadowing
- # Own first (preserve order)
- for (fname, ftype) in own_fields:
+ for fname, ftype in own_fields:
if fname not in have_names:
merged.append((fname, ftype))
have_names.add(fname)
- # BFS over upstreams
q: deque[Tuple[str, str]] = deque(upstream_neighbors(type_key))
visited: set[Tuple[str, str]] = set()
while q:
- u = q.popleft()
- if u in visited:
+ upstream_key = q.popleft()
+ if upstream_key in visited:
continue
- visited.add(u)
+ visited.add(upstream_key)
- upstream_fields = per_key_fields.get(u)
+ upstream_fields = per_key_fields.get(upstream_key)
if upstream_fields:
- for (fname, ftype) in upstream_fields:
+ for fname, ftype in upstream_fields:
if fname not in have_names:
merged.append((fname, ftype))
have_names.add(fname)
- for next_upstream in upstream_neighbors(u):
+ for next_upstream in upstream_neighbors(upstream_key):
if next_upstream not in visited:
q.append(next_upstream)
results[type_key] = merged
+ # ---------------------------
# 5) Include types that had no own fields but may inherit some
+ # ---------------------------
for type_key in rows.keys():
if type_key in results:
continue
merged: List[Tuple[str, str]] = []
have_names: set[str] = set()
+
q: deque[Tuple[str, str]] = deque(upstream_neighbors(type_key))
visited: set[Tuple[str, str]] = set()
+
while q:
- u = q.popleft()
- if u in visited:
+ upstream_key = q.popleft()
+ if upstream_key in visited:
continue
- visited.add(u)
- upstream_fields = per_key_fields.get(u)
+ visited.add(upstream_key)
+
+ upstream_fields = per_key_fields.get(upstream_key)
if upstream_fields:
- for (fname, ftype) in upstream_fields:
+ for fname, ftype in upstream_fields:
if fname not in have_names:
merged.append((fname, ftype))
have_names.add(fname)
- for next_upstream in upstream_neighbors(u):
+
+ for next_upstream in upstream_neighbors(upstream_key):
if next_upstream not in visited:
q.append(next_upstream)
+
if merged:
results[type_key] = merged
@@ -169,17 +288,43 @@ def upstream_neighbors(type_key: Tuple[str, str]) -> List[Tuple[str, str]]:
def get_dummy_function(self, function_name):
return f"public void {function_name + '()' + '{}'}"
- def get_class_name_from_class_function(self, func: Document):
+ def __get_optional_inner_type(self, method_source: str) -> Optional[str]:
"""
- Extract the final path segment and strip its last extension.
- Works with / or \ path separators. If there's no dot, returns the segment as-is.
+ Extract the optional inner type marker from a method's extracted source.
+ Returns:
+ - the inner type string (trimmed) if a marker exists
+ - None otherwise
"""
- source = func.metadata.get('source')
+ m = _INNER_TYPE_RE.search(method_source)
+ if not m:
+ return None
+
+ inner = m.group("block") if m.group("block") is not None else m.group("line")
+ inner = (inner or "").strip()
+ return inner or None
+
+ def get_class_name_from_class_function(self, func: Document) -> str:
+ """
+ Return the Java FQCN (fully-qualified class name) inferred from func.metadata['source'].
+
+ Assumptions (per your note):
+ - Either:
+ dependencies-sources/--sources//.java
+ or:
+ .../src/main/java//.java
+ - Package path is the directory structure after the marker, followed by the class filename.
+ """
+ source = (func.metadata.get("source") or "").strip()
+ if not source:
+ return ""
- # get last non-empty segment
- tail = re.split(r"[\\/]", source.rstrip("/\\"))[-1]
- # drop final extension if present
- return tail.rsplit(".", 1)[0]
+ fqcn = extract_fqcn(source)
+
+ inner_type = self.__get_optional_inner_type(func.page_content)
+ if inner_type:
+ return fqcn + "." + inner_type
+
+ return fqcn
def get_function_reserved_word(self) -> str:
return ""
@@ -619,11 +764,6 @@ def search_for_called_function(
- Constructor calls: `new Type(...)` (when the callee is a constructor).
- Method references: `left-hand-side::method` → synthesized as `left-hand-side.method(` for resolution.
- Constructor references: `left-hand-side::new` → synthesized as `new left-hand-side(` for resolution.
-
- Fast-paths:
- - Regex pre-filters for method names and constructor sites.
- - Bounded left-scan to slice the owning expression.
- - De-duplicate by (start, end) slice of the expression being resolved.
"""
def _find_matching_paren(s: str, open_idx: int) -> int:
"""Find matching ')' for the '(' at open_idx. Ignores strings/char literals."""
@@ -680,9 +820,6 @@ def _next_token_after(s: str, idx: int) -> str:
return s[i] if i < n else ''
# Precise left-scan to the start of the expression that owns this call.
- # Stops at the first top-level boundary (not inside (),[],{} or strings):
- # ; , = + - * / % ! ~ ? : & | ^ < > { } \n
- # We allow dots, identifiers, casts, and 'new ...' as part of the expression.
BOUNDARY_CHARS = set(';,=+-*/%!?::&|^<>{}\n')
def _expr_start_left(s: str, pos: int, max_back: int = 512) -> int:
@@ -751,18 +888,14 @@ def _expr_start_left(s: str, pos: int, max_back: int = 512) -> int:
i -= 1
return limit
- # Method-ref left-hand-side extractor: find the minimal "left-hand-side" immediately before '::'
+ # Method-ref left-hand-side extractor: find minimal "lhs" immediately before '::'
def _method_ref_lhs_start(s: str, dc_idx: int, max_back: int = 512) -> int:
"""
- Return the start index of the *left-hand-side* expression immediately preceding '::' at dc_idx.
-
- Unlike _expr_start_left (which finds the start of the whole owning expression),
- this stops at top-level delimiters that commonly precede a method reference inside
- an argument list (notably '(' and ',').
+ Return the start index of the LHS expression immediately preceding '::' at dc_idx.
+ Stops at top-level delimiters that commonly precede a method reference in args
+ (notably '(' and ',') so we don't capture an outer call prefix.
"""
i = dc_idx - 1
- n = len(s)
- # skip whitespace
while i >= 0 and s[i].isspace():
i -= 1
@@ -771,8 +904,6 @@ def _method_ref_lhs_start(s: str, dc_idx: int, max_back: int = 512) -> int:
dp = db = dbr = da = 0 # (), [], {}, <> (generics)
limit = max(0, dc_idx - max_back)
- # boundaries that delimit an argument/expression at top-level for method references
- # (include '(' and ',' to avoid capturing outer call prefixes like "foo(")
REF_BOUNDARY = set(';,=+-*/%!?&|^:\n,(')
while i >= limit:
@@ -806,7 +937,6 @@ def _method_ref_lhs_start(s: str, dc_idx: int, max_back: int = 512) -> int:
i -= 1
continue
- # nesting tracking (reverse)
if ch == ')':
dp += 1
i -= 1
@@ -816,7 +946,6 @@ def _method_ref_lhs_start(s: str, dc_idx: int, max_back: int = 512) -> int:
dp -= 1
i -= 1
continue
- # top-level '(' is a delimiter for the left-hand-side in method refs
if db == dbr == da == 0:
return i + 1
i -= 1
@@ -859,7 +988,6 @@ def _method_ref_lhs_start(s: str, dc_idx: int, max_back: int = 512) -> int:
da -= 1
i -= 1
continue
- # top-level '<' (e.g., comparison) is a delimiter for our left-hand-side extraction
if dp == db == dbr == 0:
return i + 1
i -= 1
@@ -872,7 +1000,9 @@ def _method_ref_lhs_start(s: str, dc_idx: int, max_back: int = 512) -> int:
return limit
- # Extract method body once
+ # ---------------------------
+ # Extract caller function body
+ # ---------------------------
src = caller_function.page_content
try:
lo = src.index("{")
@@ -881,20 +1011,59 @@ def _method_ref_lhs_start(s: str, dc_idx: int, max_back: int = 512) -> int:
except ValueError:
caller_function_body = src
- # --- patterns ---
+ # ---------------------------
+ # Patterns
+ # ---------------------------
method_pat = re.compile(
rf'(?:\breturn\b\s*\(?[^;]*?)?(?:[\w$()\[\].]*?\.?)?\b{re.escape(callee_function_name)}\s*\(',
- re.MULTILINE
+ re.MULTILINE,
)
- declaring_simple = self.get_class_name_from_class_function(callee_function)
+ callee_function_source = callee_function.metadata['source']
+
+ # CHANGED: get_class_name_from_class_function now returns FQCN (possibly inner).
+ # Use it as the declaring FQCN for inheritance lookup AND for constructor name derivation.
+ declaring_fqcn = self.get_class_name_from_class_function(callee_function)
+
+ # CHANGED: normalize inner separator to Java syntax ('.') for matching source text.
+ declaring_fqcn_dot = declaring_fqcn.replace('$', '.')
+ declaring_simple = declaring_fqcn_dot.rsplit('.', 1)[-1]
- # Treat as constructor-target if caller asked the declaring type name
- is_ctor_target = callee_function_name == declaring_simple
+ # CHANGED: support ctor targeting by simple name, no-package inner name, or fqcn.
+ def _strip_package_prefix(fqcn_dot: str) -> str:
+ parts = fqcn_dot.split('.')
+ for idx, tok in enumerate(parts):
+ if tok and tok[0].isupper():
+ return '.'.join(parts[idx:])
+ return fqcn_dot # fallback
- ctor_pat = re.compile(
- rf'\bnew\s+((?:[A-Za-z_$][\w$]*\.)*{re.escape(declaring_simple)})\s*\('
- ) if is_ctor_target else None
+ declaring_no_pkg = _strip_package_prefix(declaring_fqcn_dot)
+
+ # Treat as constructor-target if caller asked for the declaring type name (simple/no-pkg/fqcn).
+ is_ctor_target = callee_function_name in {
+ declaring_simple,
+ declaring_no_pkg,
+ declaring_fqcn_dot,
+ declaring_fqcn, # just in case callers still use '$' form
+ }
+
+ # CHANGED: ctor_pat must work when declaring_simple is derived from an FQCN/inner FQCN.
+ # We match "new <...>(..." where <...> ends with:
+ # - Simple class name (Inner)
+ # - No-package inner chain (Outer.Inner)
+ # - Fully-qualified name (pkg.Outer.Inner)
+ ctor_pat = None
+ if is_ctor_target:
+ ctor_variants = []
+ # Prefer longest first to reduce backtracking.
+ for v in (declaring_fqcn_dot, declaring_no_pkg, declaring_simple):
+ if v and v not in ctor_variants:
+ ctor_variants.append(v)
+ ctor_variants.sort(key=len, reverse=True)
+ ctor_alt = "|".join(re.escape(v) for v in ctor_variants)
+ ctor_pat = re.compile(
+ rf'\bnew\s+((?:[A-Za-z_$][\w$]*\.)*(?:{ctor_alt}))\s*\('
+ )
# Method reference patterns (allow optional method type args: Type::m)
methodref_pat = re.compile(
@@ -904,13 +1073,19 @@ def _method_ref_lhs_start(s: str, dc_idx: int, max_back: int = 512) -> int:
seen_slices: set[tuple[int, int]] = set()
- # Target class names for the resolver
- key = (declaring_simple, callee_function.metadata['source'])
+ # ---------------------------
+ # Target class names (inheritance)
+ # ---------------------------
+ # CHANGED: key uses declaring_fqcn (inner-aware), not extract_fqcn(file_path).
+ key = (declaring_fqcn, callee_function_source)
if "dummy" not in callee_function_file_name:
target_class_names = get_target_class_names(type_inheritance[key])
else:
- target_class_names = frozenset([declaring_simple])
+ target_class_names = frozenset([declaring_fqcn])
+ # ---------------------------
+ # Helpers
+ # ---------------------------
def _process_call(start_idx: int, open_paren_pos: int) -> bool:
close_paren_pos = _find_matching_paren(caller_function_body, open_paren_pos)
if close_paren_pos == -1:
@@ -938,6 +1113,7 @@ def _process_call(start_idx: int, open_paren_pos: int) -> bool:
documents_of_functions=documents_of_functions,
callee_function_name=callee_function_name,
type_inheritance=type_inheritance,
+ callee_declaring_fqcn=declaring_fqcn,
):
logger.debug(
"__check_identifier_resolved_to_callee_function_package resolved successfully - "
@@ -985,16 +1161,21 @@ def _process_method_ref(dc_idx: int, ref_len: int, make_ctor: bool) -> bool:
documents_of_functions=documents_of_functions,
callee_function_name=callee_function_name,
type_inheritance=type_inheritance,
+ callee_declaring_fqcn=declaring_fqcn,
)
+ # ---------------------------
# 1) Constructor matches (only when target is a ctor)
+ # ---------------------------
if is_ctor_target and ctor_pat:
for m in ctor_pat.finditer(caller_function_body):
open_paren_pos = m.end(0) - 1
if _process_call(m.start(), open_paren_pos):
return True
else:
+ # ---------------------------
# 2) Regular method matches
+ # ---------------------------
for m in method_pat.finditer(caller_function_body):
open_paren_pos = m.end(0) - 1
close_paren_pos = _find_matching_paren(caller_function_body, open_paren_pos)
@@ -1008,7 +1189,9 @@ def _process_method_ref(dc_idx: int, ref_len: int, make_ctor: bool) -> bool:
if _process_call(m.start(), open_paren_pos):
return True
+ # ---------------------------
# 3) Method reference matches
+ # ---------------------------
for m in methodref_pat.finditer(caller_function_body):
if _process_method_ref(m.start(), m.end() - m.start(), make_ctor=False):
return True
@@ -1020,7 +1203,6 @@ def _process_method_ref(dc_idx: int, ref_len: int, make_ctor: bool) -> bool:
return False
-
def __check_identifier_resolved_to_callee_function_package(
self,
function: Document,
@@ -1034,7 +1216,8 @@ def __check_identifier_resolved_to_callee_function_package(
target_class_names: frozenset[str],
documents_of_functions: list[Document],
callee_function_name: str,
- type_inheritance: dict[Tuple[str, str], List[Tuple[str, str]]]
+ type_inheritance: dict[Tuple[str, str], List[Tuple[str, str]]],
+ callee_declaring_fqcn: str = "",
) -> bool:
"""
Decide if the found call expression (`identifier_function`) in `function` actually targets
@@ -1048,6 +1231,9 @@ def __check_identifier_resolved_to_callee_function_package(
- Unqualified helper call in same class: helper(…) → resolve helper() return type.
- Receiver that is a cast: ((Type) x).target(…) → use Type as a strong hint.
- SUPER handling uses `target_class_names` (parents/children set) — no file header parsing.
+
+ NOTE (CHANGED): `target_class_names` now contains FQCNs (not simple names), so all calls into
+ `_type_token_matches_callee(...)` are now made with FQCN candidates only.
"""
def _strip_return_and_ws(s: str) -> str:
s = s.strip()
@@ -1071,24 +1257,37 @@ def _split_top_level_dots(expr: str) -> list[str]:
ch = expr[i]
if in_str:
buf.append(ch)
- if ch == '\\' and i + 1 < n: buf.append(expr[i+1]); i += 1
- elif ch == '"': in_str = False
+ if ch == '\\' and i + 1 < n:
+ buf.append(expr[i + 1]); i += 1
+ elif ch == '"':
+ in_str = False
i += 1; continue
if in_chr:
buf.append(ch)
- if ch == '\\' and i + 1 < n: buf.append(expr[i+1]); i += 1
- elif ch == "'": in_chr = False
+ if ch == '\\' and i + 1 < n:
+ buf.append(expr[i + 1]); i += 1
+ elif ch == "'":
+ in_chr = False
i += 1; continue
- if ch == '"': in_str = True; buf.append(ch); i += 1; continue
- if ch == "'": in_chr = True; buf.append(ch); i += 1; continue
- if ch == '(': dp += 1; buf.append(ch); i += 1; continue
- if ch == ')': dp = max(0, dp - 1); buf.append(ch); i += 1; continue
- if ch == '[': db += 1; buf.append(ch); i += 1; continue
- if ch == ']': db = max(0, db - 1); buf.append(ch); i += 1; continue
- if ch == '{': dbr += 1; buf.append(ch); i += 1; continue
- if ch == '}': dbr = max(0, dbr - 1); buf.append(ch); i += 1; continue
- if ch == '<': da += 1; buf.append(ch); i += 1; continue
+ if ch == '"':
+ in_str = True; buf.append(ch); i += 1; continue
+ if ch == "'":
+ in_chr = True; buf.append(ch); i += 1; continue
+ if ch == '(':
+ dp += 1; buf.append(ch); i += 1; continue
+ if ch == ')':
+ dp = max(0, dp - 1); buf.append(ch); i += 1; continue
+ if ch == '[':
+ db += 1; buf.append(ch); i += 1; continue
+ if ch == ']':
+ db = max(0, db - 1); buf.append(ch); i += 1; continue
+ if ch == '{':
+ dbr += 1; buf.append(ch); i += 1; continue
+ if ch == '}':
+ dbr = max(0, dbr - 1); buf.append(ch); i += 1; continue
+ if ch == '<':
+ da += 1; buf.append(ch); i += 1; continue
if ch == '>':
if da > 0: da -= 1
buf.append(ch); i += 1; continue
@@ -1137,7 +1336,6 @@ def _is_fqcn_like(token: str) -> bool:
return bool(pkg) and _is_upper_camel(cls[:1] + cls[1:])
return _is_upper_camel(token)
- # --- resolve the return type of a method declared in the same source file ---
def _find_method_return_type_in_file(source_text: str, method_name: str) -> str | None:
sig = re.compile(
rf"""
@@ -1158,13 +1356,12 @@ def _find_method_return_type_in_file(source_text: str, method_name: str) -> str
ret = re.sub(r'\s*(\[\])+$', '', ret)
return ret or None
- # peel a leading Java cast "((Type) expr)" → (cast_type | None, remainder_expr)
_CAST_RE = re.compile(
r'^\(\s*([A-Za-z_$][\w$]*(?:\.[A-Za-z_$][\w$]*)*(?:<[^()<>]*>)?(?:\s*(?:\[\]))*)\s*\)\s*(.*)\Z'
)
+
def _peel_leading_cast(expr: str) -> tuple[str | None, str]:
s = expr.strip()
- # unwrap outermost parens only if they wrap the whole expr (not a call)
changed = True
while changed and s.startswith('(') and s.endswith(')'):
changed = False
@@ -1179,7 +1376,6 @@ def _peel_leading_cast(expr: str) -> tuple[str | None, str]:
else:
s = s[1:-1].strip()
changed = True
- # peel one or more casts
last_type = None
rest = s
while True:
@@ -1190,7 +1386,6 @@ def _peel_leading_cast(expr: str) -> tuple[str | None, str]:
rest = m.group(2).strip()
return last_type, rest
- # SAFER: extract receiver + method; never builds an empty list or raises IndexError
def _extract_recv_and_method(snippet: str) -> tuple[str | None, str | None]:
s = snippet.strip()
m = re.search(r'([A-Za-z_$][\w$]*)\s*\($', s)
@@ -1205,13 +1400,12 @@ def _extract_recv_and_method(snippet: str) -> tuple[str | None, str | None]:
method = m.group(1)
method_start = m.start(1)
- # immediate '.' before the method token?
i = method_start - 1
- while i >= 0 and s[i].isspace(): i -= 1
+ while i >= 0 and s[i].isspace():
+ i -= 1
if i < 0 or s[i] != '.':
return None, method
- # walk left to get the *smallest* receiver just before the dot
dot_idx = i
def _match_paren_reverse(text: str, close_idx: int) -> int:
@@ -1242,10 +1436,10 @@ def _match_paren_reverse(text: str, close_idx: int) -> int:
return -1
j = dot_idx - 1
- while j >= 0 and s[j].isspace(): j -= 1
+ while j >= 0 and s[j].isspace():
+ j -= 1
if j >= 0 and s[j] == ')':
- # get the minimal '( ... )' group immediately before the dot
open_idx = _match_paren_reverse(s, j)
if open_idx == -1:
start = j
@@ -1254,16 +1448,15 @@ def _match_paren_reverse(text: str, close_idx: int) -> int:
recv = s[start + 1:dot_idx].strip()
return (recv if recv else None), method
- # If the '(' belongs to a *call* (e.g. foo(...)) — include the callee token to its left.
k = open_idx - 1
- while k >= 0 and s[k].isspace(): k -= 1
+ while k >= 0 and s[k].isspace():
+ k -= 1
if k >= 0 and (s[k].isalnum() or s[k] in '_$.'):
start = k
while start >= 0 and (s[start].isalnum() or s[start] in '_$.'):
start -= 1
recv = s[start + 1:dot_idx].strip()
else:
- # parenthesized group/cast → keep minimal group
recv = s[open_idx:dot_idx].strip()
else:
start = j
@@ -1275,51 +1468,161 @@ def _match_paren_reverse(text: str, close_idx: int) -> int:
caller_src = function.metadata.get('source')
caller_doc = code_documents.get(caller_src)
- caller_text = caller_doc.page_content if caller_doc else ""
- caller_class_name = self.get_class_name_from_class_function(caller_doc) if caller_doc else ""
- # Does a type token (simple or qualified) belong to the *callee type(s)/package*?
- def _type_token_matches_callee(type_token: str) -> bool:
- if not type_token:
+ caller_text = (caller_doc.page_content if caller_doc else "") # safe fallback
+ caller_fqcn = self.get_class_name_from_class_function(function)
+
+ # -------------------- FQCN-only type matching --------------------
+ # Build quick index: simple-name -> list of target FQCNs (for fast disambiguation)
+ _target_by_simple: dict[str, list[str]] = {}
+ for fq in target_class_names:
+ last = fq.rsplit('.', 1)[-1]
+ _target_by_simple.setdefault(last, []).append(fq)
+ if '$' in last:
+ _target_by_simple.setdefault(last.split('$')[-1], []).append(fq)
+
+ # Parse header imports once (bounded), used only for disambiguation (still cheap)
+ _hdr = caller_text
+ _PKG_RE = re.compile(r'^\s*package\s+([\w.]+)\s*;\s*$', re.MULTILINE)
+ _IMP_RE = re.compile(r'^\s*import\s+(?:static\s+)?([\w$.]+)\s*;\s*$', re.MULTILINE)
+
+ m_pkg = _PKG_RE.search(_hdr)
+ _caller_pkg = (m_pkg.group(1) if m_pkg else "")
+ _explicit_imports: dict[str, str] = {} # simple -> fqcn (CHANGED)
+ _wild_import_pkgs: list[str] = [] # "a.b" for "import a.b.*" (CHANGED)
+ for imp in _IMP_RE.findall(_hdr):
+ if imp.endswith(".*"):
+ _wild_import_pkgs.append(imp[:-2])
+ else:
+ _explicit_imports[imp.rsplit('.', 1)[-1]] = imp
+
+ def _strip_type_syntax(token: str) -> str:
+ """Strip generics, arrays, and wildcard bounds from a type-ish token (CHANGED)."""
+ t = (token or "").strip()
+ if not t:
+ return ""
+ # Drop keywords after the type (e.g. "? extends Foo")
+ for kw in ("extends", "super"):
+ if kw in t:
+ t = t.split(kw, 1)[0].strip()
+ t = t.replace("?", "").strip()
+
+ # Strip generics with nesting
+ out = []
+ depth = 0
+ for ch in t:
+ if ch == '<':
+ depth += 1
+ continue
+ if ch == '>':
+ if depth > 0:
+ depth -= 1
+ continue
+ if depth == 0:
+ out.append(ch)
+ t = ''.join(out).strip()
+
+ # Strip trailing arrays
+ while t.endswith("]") and "[]" in t:
+ t = re.sub(r'\s*(\[\])\s*$', '', t).strip()
+
+ return t
+
+ def _is_package_qualified(name: str) -> bool:
+ """Heuristic: package-qualified names typically start with a lower-case segment (CHANGED)."""
+ first = name.split('.', 1)[0]
+ return bool(first) and first[0].islower()
+
+ def _iter_fqcn_candidates(raw_type_token: str):
+ """
+ Yield FQCN candidates for a raw type token. Ensures that downstream
+ `_type_token_matches_callee(...)` is invoked with FQCNs only (CHANGED).
+ """
+ t = _strip_type_syntax(raw_type_token)
+ if not t:
+ return
+ # Already looks like a FQCN: pkg.Type or pkg.Outer.Inner
+ if '.' in t and _is_package_qualified(t):
+ yield t
+ return
+
+ # Nested type without package: Outer.Inner -> prefix caller package if known
+ if '.' in t and not _is_package_qualified(t):
+ if _caller_pkg:
+ yield f"{_caller_pkg}.{t}"
+ else:
+ # best-effort; still a "fqcn-ish" candidate
+ yield t
+ return
+
+ # Simple token: try target allow-list by simple name
+ cands = _target_by_simple.get(t)
+ if cands:
+ for fq in cands:
+ yield fq
+
+ # Explicit import disambiguation
+ imp = _explicit_imports.get(t)
+ if imp:
+ yield imp
+
+ # Wildcard imports: only usable if we can cheaply construct candidates
+ # (we only build candidates that are already in target_class_names to avoid work)
+ if _wild_import_pkgs and cands:
+ # already covered by `cands`; do nothing
+ pass
+
+ # Same-package fallback
+ if _caller_pkg:
+ yield f"{_caller_pkg}.{t}"
+
+ def _type_token_matches_callee(type_fqcn: str) -> bool:
+ """
+ (CHANGED) This function now assumes it receives a FQCN token only.
+ All callers must pass FQCNs (or FQCN candidates).
+ """
+ if not type_fqcn:
return False
- tt = re.sub(r'<[^<>]*>', '', type_token).strip()
- tt = re.sub(r'\s*(\[\])+$', '', tt)
- simple_tt = tt.rsplit('.', 1)[-1]
- # IMPORTANT: If allow-list exists and the token's simple name is not in it, short-circuit False.
- if target_class_names and strip_java_generics(simple_tt) not in target_class_names:
+
+ fq = _strip_type_syntax(type_fqcn) # defensive: keep it cheap
+ if not fq:
+ return False
+
+ # IMPORTANT (CHANGED): target_class_names now contains FQCNs; membership is exact.
+ if target_class_names and fq not in target_class_names:
return False
- code_text = caller_text
- if '.' in tt and _is_fqcn_like(tt):
- if self.is_package_imported(code_text, tt):
+
+ if _is_fqcn_like(fq):
+ if self.is_package_imported(caller_text, fq):
return True
+
matches = self.__get_type_docs_matched_with_callee_type(
- callee_package, tt, type_documents, target_class_names
+ callee_package, fq, type_documents, target_class_names
)
return len(matches) > 0
+ def _type_matches_callee(raw_type_token: str) -> bool:
+ """Resolve raw token -> FQCN candidates, then call `_type_token_matches_callee` (CHANGED)."""
+ seen = set()
+ for fq in _iter_fqcn_candidates(raw_type_token):
+ if fq in seen:
+ continue
+ seen.add(fq)
+ if _type_token_matches_callee(fq): # fqcn-only
+ return True
+ return False
+
def _caller_key() -> str:
content = function.page_content
md5hex = hashlib.md5(content.strip().encode("utf-8")).hexdigest()
return f"{extract_method_name_with_params(content)}@{md5hex}@{caller_src}"
- # ---- normalize the found call snippet ----
expr = _strip_return_and_ws(_strip_leading_this(identifier_function))
expr = re.sub(r'\s+', ' ', expr).strip()
def _extract_ctor_type(expr: str) -> str:
- """
- Given an expression that starts with 'new ', return the constructor type token.
- Handles qualified names, generics (including diamond), and array brackets.
- Stops at the first '(' or '{' seen at top level (outside '<...>').
- Examples:
- - 'new Foo(' -> 'Foo'
- - 'new a.b.C(' -> 'a.b.C'
- - 'new HashMap>(' -> 'HashMap>'
- - 'new int[] {' -> 'int[]'
- """
i = 4 # skip 'new '
n = len(expr)
- # skip leading spaces
while i < n and expr[i].isspace():
i += 1
buf = []
@@ -1335,63 +1638,70 @@ def _extract_ctor_type(expr: str) -> str:
buf.append(ch); i += 1; continue
if (ch == '(' or ch == '{') and angle == 0:
break
- # include identifiers, dots, '$', arrays, spaces between [] if present
buf.append(ch); i += 1
return ''.join(buf).strip()
# ====== EARLY: direct constructor snippet 'new Type(...' ======
- # This covers cases where search_for_called_function fed us a constructor occurrence directly.
if expr.startswith('new '):
ctor_type_token = _extract_ctor_type(expr)
- return _type_token_matches_callee(ctor_type_token)
+ return _type_matches_callee(ctor_type_token) # resolve -> FQCN(s)
recv_norm, method_norm = _extract_recv_and_method(expr)
if method_norm:
expr = (f"{recv_norm}.{method_norm}(" if recv_norm else f"{method_norm}(").strip()
- # decide qualified vs unqualified using recv_norm
is_unqualified = (recv_norm is None)
# ======== SUPER / THIS handling via target_class_names ========
if recv_norm is not None:
recv_trim = recv_norm.strip()
- # Interface.super.method(…)
if recv_trim.endswith('.super'):
- iface_type = recv_trim[:-6].strip() # drop ".super"
- if _type_token_matches_callee(iface_type):
+ iface_type = recv_trim[:-6].strip()
+ if _type_matches_callee(iface_type):
return True
- # plain super.method(…)
- if recv_trim == 'super':
- caller_inheritance_list = get_target_class_names(type_inheritance[(caller_class_name, caller_src)])
- for cand in caller_inheritance_list:
- if cand == strip_java_generics(caller_class_name):
- continue
- if _type_token_matches_callee(cand):
- return True
+ if recv_trim == 'super' and callee_declaring_fqcn:
+ # Don't match super calls from root-package functions.
+ # super delegates to the parent's own implementation which is
+ # self-contained. Allowing it at the app→dependency boundary
+ # lets the CCA's backtracking route through the entire
+ # dependency hierarchy via polymorphic dispatch, producing
+ # false-positive chains (e.g. handler.handle() hops).
+ if self.is_root_package(function):
+ pass # skip – super from app code is not a valid entry point
+ else:
+ try:
+ caller_inheritance = type_inheritance[(caller_fqcn, caller_src)]
+ except KeyError:
+ caller_inheritance = []
+ # super resolves to the direct parent class only; match only if the
+ # direct parent IS the callee's declaring class (exact match).
+ for cand_fqcn, _cand_src in caller_inheritance:
+ if cand_fqcn == caller_fqcn:
+ continue
+ if cand_fqcn == callee_declaring_fqcn or cand_fqcn == callee_declaring_fqcn.replace('$', '.'):
+ return True
+ break
- # explicit this.method(…)
if recv_trim == 'this':
- if caller_class_name and _type_token_matches_callee(caller_class_name):
+ if caller_fqcn and _type_token_matches_callee(caller_fqcn):
return True
if is_unqualified:
- # helper method in same class? use its return type
base_before_paren = method_norm
if base_before_paren and callee_function_name != base_before_paren:
fn = find_function(caller_src, base_before_paren, documents_of_functions)
if fn:
ret_type = _find_method_return_type_in_file(fn.page_content, base_before_paren)
- if ret_type and _type_token_matches_callee(ret_type):
+ if ret_type and _type_matches_callee(ret_type):
return True
- # fallback: same-package unqualified call rule
try:
callee_doc = code_documents[callee_function_file_name]
callee_pkg_decl = (self.get_package_name_file(callee_doc) or "").strip()
except KeyError:
- callee_pkg_decl = callee_function_file_name # fallback
+ callee_pkg_decl = callee_function_file_name
caller_pkg_decl = (self.get_package_name_file(caller_doc) or "").strip() if caller_doc else ""
jar_name = extract_jar_name(caller_src or "")
@@ -1399,32 +1709,28 @@ def _extract_ctor_type(expr: str) -> str:
callee_package.__contains__(jar_name)
and callee_pkg_decl
and callee_pkg_decl == caller_pkg_decl
- and strip_java_generics(caller_class_name) in target_class_names
+ and (caller_fqcn in target_class_names if caller_fqcn else False)
)
- # ---- dotted/chained expression handling ----
chain = _split_top_level_dots(expr)
if not chain:
return False
- recv_raw = chain[-2] if len(chain) >= 2 else ""
- recv_raw = recv_raw.strip()
+ recv_raw = (chain[-2] if len(chain) >= 2 else "").strip()
- # Casts first: '((Type) x).method('
cast_type, post_cast_expr = _peel_leading_cast(recv_raw)
if cast_type:
- if _type_token_matches_callee(cast_type):
+ if _type_matches_callee(cast_type):
return True
m_id_cast = re.match(r'\s*([A-Za-z_$][\w$]*)', post_cast_expr)
if m_id_cast:
- caller_key = _caller_key()
traced_cast = self.__trace_down_package(
expression=m_id_cast.group(1),
type_documents=type_documents,
callee_package=callee_package,
fields_of_types=fields_of_types,
functions_local_variables_index=functions_local_variables_index,
- caller_function_index=caller_key,
+ caller_function_index=_caller_key(),
target_class_names=target_class_names,
function=function,
code_documents=code_documents
@@ -1432,33 +1738,41 @@ def _extract_ctor_type(expr: str) -> str:
if traced_cast:
return True
- # Also allow Interface.super receiver in a chain
if recv_raw.endswith('.super'):
iface_type = recv_raw[:-6].strip()
- if _type_token_matches_callee(iface_type):
+ if _type_matches_callee(iface_type):
return True
- # Explicit super/this in chained receiver
- if recv_raw == 'super':
- caller_inheritance_list = get_target_class_names(type_inheritance[(caller_class_name, caller_src)])
- for cand in caller_inheritance_list:
- if cand == strip_java_generics(caller_class_name):
- continue
- if _type_token_matches_callee(cand):
- return True
+ if recv_raw == 'super' and callee_declaring_fqcn:
+ # Don't match super calls from root-package functions (see early-path comment).
+ if self.is_root_package(function):
+ pass # skip
+ else:
+ try:
+ caller_inheritance = type_inheritance[(caller_fqcn, caller_src)]
+ except KeyError:
+ caller_inheritance = []
+ # super resolves to the direct parent class only; match only if the
+ # direct parent IS the callee's declaring class (exact match).
+ for cand_fqcn, _cand_src in caller_inheritance:
+ if cand_fqcn == caller_fqcn:
+ continue
+ if cand_fqcn == callee_declaring_fqcn or cand_fqcn == callee_declaring_fqcn.replace('$', '.'):
+ return True
+ break
elif recv_raw == 'this':
- if caller_class_name and _type_token_matches_callee(caller_class_name):
+ if caller_fqcn and _type_token_matches_callee(caller_fqcn):
return True
- # Case A: any static-style class root in the chain (includes 'new Type(...)' receivers)
+ # Case A: static-style class root in the chain (includes 'new Type(...)' receivers)
for seg in chain[:-1]:
seg_core = _strip_call_parens(seg).strip()
if seg_core.startswith("new "):
typ = _strip_call_parens(seg_core[4:].strip())
- if _type_token_matches_callee(typ):
+ if _type_matches_callee(typ):
return True
continue
- if _is_fqcn_like(seg_core) and _type_token_matches_callee(seg_core):
+ if _is_fqcn_like(seg_core) and _type_matches_callee(seg_core):
return True
# Case B: identifier/expression → dataflow OR rightmost receiver-call return type
@@ -1494,58 +1808,118 @@ def _extract_ctor_type(expr: str) -> str:
recv_call_name = mtok.group(1)
break
if recv_call_name:
- # Use local method return type (cheap heuristic)
ret_type = _find_method_return_type_in_file(caller_text, recv_call_name)
- if ret_type and _type_token_matches_callee(ret_type):
+ if ret_type and _type_matches_callee(ret_type):
return True
# Case C: final receiver is a direct 'new Type(...)'
if recv_raw.startswith("new "):
typ = _strip_call_parens(recv_raw[4:].strip())
- if _type_token_matches_callee(typ):
+ if _type_matches_callee(typ):
return True
# Case D: immediate receiver is a (pkg.)Class token
- if _is_fqcn_like(recv_raw) and _type_token_matches_callee(recv_raw):
+ if _is_fqcn_like(recv_raw) and _type_matches_callee(recv_raw):
return True
return False
- def __trace_down_package(self, expression: str, type_documents: list[Document],
- callee_package: str, fields_of_types: dict[tuple, list[tuple]],
- functions_local_variables_index: dict[str, dict],
- caller_function_index: str,
- target_class_names: frozenset[str],
- function: Document,
- code_documents: dict[str, Document]) -> bool:
-
- variables_mappings = functions_local_variables_index[caller_function_index]
+ def __trace_down_package(
+ self,
+ expression: str,
+ type_documents: list[Document],
+ callee_package: str,
+ fields_of_types: dict[tuple, list[tuple]],
+ functions_local_variables_index: dict[str, dict],
+ caller_function_index: str,
+ target_class_names: frozenset[str],
+ function: Document,
+ code_documents: dict[str, Document],
+ ) -> bool:
+ variables_mappings = functions_local_variables_index.get(caller_function_index, {}) # CHANGED: safe fallback
parts = expression.split(".")
result = False
+ def _normalize_type_token(t: str) -> str:
+ """Strip generics + array suffixes; keep dots/$ for FQCNs."""
+ if not t:
+ return ""
+ t = t.strip()
+ t = strip_java_generics(t).strip()
+ t = re.sub(r"\s*(\[\])+$", "", t).strip()
+ return t
+
+ def _fqcn_candidates_from_token(type_token: str) -> list[str]:
+ """
+ Return FQCN candidates derived from `type_token`, restricted to `target_class_names`.
+ - If token is already an exact member of target_class_names, return it.
+ - Else map by simple-name suffix against target_class_names (handles Inner via '$').
+ """
+ t = _normalize_type_token(type_token)
+ if not t:
+ return []
+
+ # Exact match (already FQCN in allow-list)
+ if t in target_class_names:
+ return [t]
+
+ # Derive simple name and match against allow-list by suffix
+ simple = t.rsplit(".", 1)[-1]
+ out: list[str] = []
+ dot_suffix = "." + simple
+ dollar_suffix = "$" + simple
+ for fq in target_class_names:
+ # fq is FQCN; match on end
+ if fq.endswith(dot_suffix) or fq.endswith(dollar_suffix):
+ out.append(fq)
+
+ return out
+
+ def _has_matching_type(type_token: str) -> bool:
+ """Call __get_type_docs_matched_with_callee_type with FQCNs only."""
+ for fq in _fqcn_candidates_from_token(type_token):
+ if self.__get_type_docs_matched_with_callee_type(
+ callee_package, fq, type_documents, target_class_names
+ ):
+ return True
+ return False
+
(resolved_type, struct_initializer_expression,
value_list, var_properties) = self.__prepare_package_lookup(parts, variables_mappings, the_part=-1)
- if (var_properties is not None
- and (struct_initializer_expression or
- resolved_type not in LOCAL_INDIRECT_TYPES_INDICATIONS or PARAMETER in value_list)):
- result = self.__lookup_package(callee_package, resolved_type, struct_initializer_expression,
- type_documents, value_list, target_class_names)
+ if (
+ var_properties is not None
+ and (
+ struct_initializer_expression
+ or resolved_type not in LOCAL_INDIRECT_TYPES_INDICATIONS
+ or PARAMETER in value_list
+ )
+ ):
+ result = self.__lookup_package(
+ callee_package=callee_package,
+ resolved_type=resolved_type,
+ struct_initializer_expression=struct_initializer_expression,
+ type_documents=type_documents,
+ value_list=value_list,
+ target_class_names=target_class_names
+ )
# Property/member is not in function, check if it's member/property of a type
- elif var_properties is None and (expression[0].islower() or expression[0] == '_'):
- class_name = self.get_class_name_from_class_function(function)
- possible_types = {key: value for (key, value) in fields_of_types.items()
- if key == (class_name, function.metadata['source'])}
+ elif var_properties is None and (expression and (expression[0].islower() or expression[0] == '_')):
+ fqcn = self.get_class_name_from_class_function(function)
+ possible_types = {
+ key: value for (key, value) in fields_of_types.items()
+ if key == (fqcn, function.metadata['source'])
+ }
for mappings in possible_types.values():
for mapping in mappings:
if expression in mapping:
- returned_matched_types = self.__get_type_docs_matched_with_callee_type(callee_package,
- mapping[1],
- type_documents,
- target_class_names)
- if len(returned_matched_types) > 0:
+ # CHANGED: mapping[1] may be simple; resolve to FQCN candidates first
+ if _has_matching_type(mapping[1]):
result = True
+ break
+ if result:
+ break
elif var_properties is not None:
value = var_properties.get("value", None)
@@ -1558,12 +1932,10 @@ def __trace_down_package(self, expression: str, type_documents: list[Document],
code_documents=code_documents
)
if inferred:
- returned_matched_types = self.__get_type_docs_matched_with_callee_type(callee_package,
- inferred,
- type_documents,
- target_class_names)
- if len(returned_matched_types) > 0:
+ # CHANGED: inferred may be simple; resolve to FQCN candidates first
+ if _has_matching_type(inferred):
result = True
+
return result
def document_imports_package(self, documents: dict[str, Document], package_name: str) -> list[Document]:
@@ -2880,64 +3252,240 @@ def __prepare_package_lookup(self, parts, variables_mappings, the_part: int):
else:
return None, None, None, None
- def __lookup_package(self, callee_package, resolved_type, struct_initializer_expression, type_documents,
- value_list, target_class_names: frozenset[str]) -> bool:
- result = False
+ @staticmethod
+ def _normalize_type_token(t: str) -> str:
+ """
+ Normalize a Java type token for matching:
+ - strip leading/trailing whitespace
+ - strip generics
+ - strip trailing array suffixes
+ Keeps dots/$ for FQCNs.
+ """
+ if not t:
+ return ""
+ t = t.strip()
+ t = strip_java_generics(t).strip()
+ t = _ARRAY_SUFFIX_RE.sub("", t).strip()
+ return t
+
+ @staticmethod
+ def _simple_name_from_type_token(t: str) -> str:
+ """
+ Extract the simple name from either a simple type, FQCN, or inner-class FQCN.
+ Uses the last separator among '.' and '$'.
+ """
+ if not t:
+ return ""
+ t = t.strip()
+ i_dot = t.rfind(".")
+ i_dol = t.rfind("$")
+ i = i_dot if i_dot > i_dol else i_dol
+ return t[i + 1:] if i != -1 else t
+
+ @staticmethod
+ @lru_cache(maxsize=256)
+ def _simple_to_fqcns_index(target_class_names: frozenset[str]) -> dict[str, tuple[str, ...]]:
+ """
+ Build an index: simpleName -> tuple(FQCNs). Cached per frozenset instance/value.
+ This avoids O(|target_class_names|) scans per resolution.
+ """
+ idx: dict[str, list[str]] = {}
+ for fq in target_class_names:
+ simple = JavaLanguageFunctionsParser._simple_name_from_type_token(fq)
+ if not simple:
+ continue
+ idx.setdefault(simple, []).append(fq)
+ return {k: tuple(v) for k, v in idx.items()}
+
+ def _fqcn_candidates_from_token(self, type_token: str, target_class_names: frozenset[str]) -> tuple[str, ...]:
+ """
+ Map a possibly-simple type token to FQCN candidates strictly within `target_class_names`.
+ - If token already equals an allowed FQCN => (token,)
+ - Else => all allowed FQCNs that share the same simple name
+ """
+ t = self._normalize_type_token(type_token)
+ if not t:
+ return ()
+ if t in target_class_names:
+ return (t,)
+ simple = self._simple_name_from_type_token(t)
+ if not simple:
+ return ()
+ return self._simple_to_fqcns_index(target_class_names).get(simple, ())
+
+ def _has_matching_type_in_package(
+ self,
+ callee_package: str,
+ type_token: str,
+ type_documents: list[Document],
+ target_class_names: frozenset[str],
+ ) -> bool:
+ """
+ Calls __get_type_docs_matched_with_callee_type with FQCN candidates only.
+ """
+ for fq in self._fqcn_candidates_from_token(type_token, target_class_names):
+ if self.__get_type_docs_matched_with_callee_type(callee_package, fq, type_documents, target_class_names):
+ return True
+ return False
+
+ @staticmethod
+ def _split_top_level_dots(expr: str) -> list[str]:
+ parts, buf = [], []
+ dp = db = dbr = da = 0
+ in_str = in_chr = False
+ i, n = 0, len(expr)
+ while i < n:
+ ch = expr[i]
+ if in_str:
+ buf.append(ch)
+ if ch == '\\' and i + 1 < n:
+ buf.append(expr[i + 1]); i += 1
+ elif ch == '"':
+ in_str = False
+ i += 1
+ continue
+ if in_chr:
+ buf.append(ch)
+ if ch == '\\' and i + 1 < n:
+ buf.append(expr[i + 1]); i += 1
+ elif ch == "'":
+ in_chr = False
+ i += 1
+ continue
+
+ if ch == '"':
+ in_str = True; buf.append(ch); i += 1; continue
+ if ch == "'":
+ in_chr = True; buf.append(ch); i += 1; continue
+ if ch == '(':
+ dp += 1; buf.append(ch); i += 1; continue
+ if ch == ')':
+ dp = max(0, dp - 1); buf.append(ch); i += 1; continue
+ if ch == '[':
+ db += 1; buf.append(ch); i += 1; continue
+ if ch == ']':
+ db = max(0, db - 1); buf.append(ch); i += 1; continue
+ if ch == '{':
+ dbr += 1; buf.append(ch); i += 1; continue
+ if ch == '}':
+ dbr = max(0, dbr - 1); buf.append(ch); i += 1; continue
+ if ch == '<':
+ da += 1; buf.append(ch); i += 1; continue
+ if ch == '>':
+ if da > 0: da -= 1
+ buf.append(ch); i += 1; continue
+
+ if ch == '.' and dp == db == dbr == da == 0:
+ parts.append(''.join(buf).strip()); buf = []; i += 1; continue
+
+ buf.append(ch); i += 1
+
+ parts.append(''.join(buf).strip())
+ return [p for p in parts if p]
+
+ @staticmethod
+ def _strip_call_parens(s: str) -> str:
+ s2 = s.rstrip()
+ if not s2.endswith(')'):
+ return s
+ depth = 0
+ in_str = in_chr = False
+ i = len(s2) - 1
+ while i >= 0:
+ ch = s2[i]
+ if in_str:
+ if ch == '\\': i -= 2; continue
+ if ch == '"': in_str = False; i -= 1; continue
+ i -= 1; continue
+ if in_chr:
+ if ch == '\\': i -= 2; continue
+ if ch == "'": in_chr = False; i -= 1; continue
+ i -= 1; continue
+ if ch == '"': in_str = True; i -= 1; continue
+ if ch == "'": in_chr = True; i -= 1; continue
+ if ch == ')': depth += 1; i -= 1; continue
+ if ch == '(':
+ depth -= 1
+ if depth == 0:
+ return s2[:i].rstrip()
+ i -= 1; continue
+ i -= 1
+ return s
+
+ @staticmethod
+ @lru_cache(maxsize=2048)
+ def _method_sig_re(method_name: str) -> re.Pattern:
+ return re.compile(
+ rf"""
+ (?P<ret>[A-Za-z_.$?][\w$.<>\[\]?]*?) # return type
+ \s+{re.escape(method_name)}\s*\( # method name + '('
+ """,
+ re.VERBOSE,
+ )
+
+ @staticmethod
+ @lru_cache(maxsize=2048)
+ def _type_header_re(simple_type_name: str) -> re.Pattern:
+ return re.compile(rf"\b(class|interface|enum|record)\b\s+{re.escape(simple_type_name)}\b")
+
+ # -------------------------
+ # Updated functions (DEDUPED)
+ # -------------------------
+
+ def __lookup_package(
+ self,
+ callee_package,
+ resolved_type,
+ struct_initializer_expression,
+ type_documents,
+ value_list,
+ target_class_names: frozenset[str],
+ ) -> bool:
if not struct_initializer_expression and resolved_type not in JAVA_METHOD_PRIM_TYPES:
- docs = self.__get_type_docs_matched_with_callee_type(callee_package, resolved_type, type_documents, target_class_names)
+ if resolved_type and resolved_type not in JAVA_METHOD_PRIM_TYPES:
+ if self._has_matching_type_in_package(
+ callee_package, resolved_type, type_documents, target_class_names
+ ):
+ return True
- if len(docs) > 0:
- result = True
- elif PARAMETER in value_list:
- result = False
+ if PARAMETER in value_list:
+ return False
elif struct_initializer_expression:
- struct_type = (struct_initializer_expression.group(0)) # TODO struct_initializer_expression is a list of expressions
- docs = self.__get_type_docs_matched_with_callee_type(callee_package, struct_type, type_documents, target_class_names)
- if len(docs) > 0:
- result = True
- return result
+ struct_type = struct_initializer_expression.group(0) # TODO list of expressions
+ if self._has_matching_type_in_package(
+ callee_package, struct_type, type_documents, target_class_names
+ ):
+ return True
- def __get_type_docs_matched_with_callee_type(self, callee_package, checked_type, type_documents, target_class_names: frozenset[str]) -> list[
- Document]:# TODO Make sure Fix works - the target classes can be in other jars
- """
- Return all type_documents whose jar "package" matches callee_package and whose
- (generic-stripped) type name is in target_class_names.
+ return False
- Performance notes:
- - Short-circuits early if checked_type cannot possibly match.
- - Normalizes callee_package once.
- - Reuses cached extract_jar_name, get_type_name, and strip_java_generics.
- """
- # Early exit if the checked_type is present and does not match any target class
- if checked_type:
- stripped_checked = strip_java_generics(checked_type)
- if stripped_checked not in target_class_names:
- return []
+ def __get_type_docs_matched_with_callee_type(
+ self,
+ callee_package,
+ checked_type,
+ type_documents,
+ target_class_names: frozenset[str]
+ ) -> list[Document]:
+ checked = strip_java_generics(checked_type or "").strip()
+ if not checked or checked not in target_class_names:
+ return []
- # Normalize the "package" part of callee_package once (after first ':')
_, _, callee_pkg_tail = callee_package.partition(":")
callee_pkg_tail_lower = callee_pkg_tail.lower()
- result: List["Document"] = []
-
+ result: list[Document] = []
for a_type in type_documents:
- src = a_type.metadata['source']
-
- # extract_jar_name is cached and fairly cheap after first call
- jar_name = extract_jar_name(src)
-
- # Same logic as is_same_package(extract_jar_name(...), callee_package.partition(":")[2])
- if jar_name.lower() != callee_pkg_tail_lower:
+ src = a_type.metadata["source"]
+ if extract_jar_name(src).lower() != callee_pkg_tail_lower and self.dir_name_for_3rd_party_packages() not in src :
continue
- # get_type_name is cached per source text
- type_name = get_type_name(a_type.page_content)
- if not type_name:
- continue
-
- # Preserve original logic: strip generics before membership check
- if strip_java_generics(type_name) in target_class_names:
+ fqcn = self.get_class_name_from_class_function(a_type)
+ if fqcn == checked: # FIX: match the requested target, not any target
result.append(a_type)
return result
@@ -2949,145 +3497,74 @@ def _find_method_return_type_in_type_docs(
type_documents: list[Document],
) -> str | None:
"""
- Find the declared return type of `receiver_type#method_name(...)` by scanning
- the matching type document in `type_documents`. Returns a normalized type
- (generics/arrays stripped) or None if not found.
+ Same logic, but reuses cached regex helpers (no duplicated regex construction).
"""
- # Normalize the receiver simple name (e.g. "org.quartz.JobDetail" -> "JobDetail")
- simple = re.sub(r'\s+', '', receiver_type).rsplit('.', 1)[-1]
- simple = re.sub(r'<[^<>]*>', '', simple)
- simple = re.sub(r'\s*(\[\])+$', '', simple)
+ simple = re.sub(r"\s+", "", receiver_type)
+ simple = simple.rsplit(".", 1)[-1]
+ simple = re.sub(r"<[^<>]*>", "", simple)
+ simple = _ARRAY_SUFFIX_RE.sub("", simple)
- # A light signature regex (like the one you already use)
- sig = re.compile(
- rf"""
- (?P<ret>[A-Za-z_.$?][\w$.<>\[\]?]*?) # return type
- \s+{re.escape(method_name)}\s*\( # method name + '('
- """,
- re.VERBOSE,
- )
+ sig = self._method_sig_re(method_name)
+ header = self._type_header_re(simple)
- # Pick the first doc whose header declares this simple type
- header = re.compile(rf'\b(class|interface|enum|record)\b\s+{re.escape(simple)}\b')
for doc in type_documents:
src = doc.page_content
if not src:
continue
if not header.search(src):
continue
+
m = sig.search(src)
if not m:
- # Could be overloaded; try a looser scan (multiple matches)
matches = list(sig.finditer(src))
if not matches:
continue
m = matches[0]
- ret = m.group('ret').strip()
- ret = re.sub(r'<[^<>]*>', '', ret) # strip generics
- ret = re.sub(r'\s*(\[\])+$', '', ret) # strip array suffixes
+ ret = m.group("ret").strip()
+ ret = re.sub(r"<[^<>]*>", "", ret)
+ ret = _ARRAY_SUFFIX_RE.sub("", ret).strip()
return ret or None
return None
def _infer_type_from_var_initializer(
self,
- initializer: str, # e.g. "getDescriptorById(conn, recording.remoteId)"
- variables_mappings: dict, # per-function locals/params map
+ initializer: str,
+ variables_mappings: dict,
type_documents: list[Document],
- caller_src: str, # function.metadata['source']
+ caller_src: str,
code_documents: dict[str, Document],
) -> str | None:
- def _split_top_level_dots(expr: str) -> list[str]:
- parts, buf = [], []
- dp = db = dbr = da = 0
- in_str = in_chr = False
- i, n = 0, len(expr)
- while i < n:
- ch = expr[i]
- if in_str:
- buf.append(ch)
- if ch == '\\' and i + 1 < n: buf.append(expr[i+1]); i += 1
- elif ch == '"': in_str = False
- i += 1; continue
- if in_chr:
- buf.append(ch)
- if ch == '\\' and i + 1 < n: buf.append(expr[i+1]); i += 1
- elif ch == "'": in_chr = False
- i += 1; continue
-
- if ch == '"': in_str = True; buf.append(ch); i += 1; continue
- if ch == "'": in_chr = True; buf.append(ch); i += 1; continue
- if ch == '(': dp += 1; buf.append(ch); i += 1; continue
- if ch == ')': dp = max(0, dp - 1); buf.append(ch); i += 1; continue
- if ch == '[': db += 1; buf.append(ch); i += 1; continue
- if ch == ']': db = max(0, db - 1); buf.append(ch); i += 1; continue
- if ch == '{': dbr += 1; buf.append(ch); i += 1; continue
- if ch == '}': dbr = max(0, dbr - 1); buf.append(ch); i += 1; continue
- if ch == '<': da += 1; buf.append(ch); i += 1; continue
- if ch == '>':
- if da > 0: da -= 1
- buf.append(ch); i += 1; continue
-
- if ch == '.' and dp == db == dbr == da == 0:
- parts.append(''.join(buf).strip()); buf = []; i += 1; continue
- buf.append(ch); i += 1
- parts.append(''.join(buf).strip())
- return [p for p in parts if p]
-
- def _strip_call_parens(s: str) -> str:
- s2 = s.rstrip()
- if not s2.endswith(')'):
- return s
- depth = 0
- in_str = in_chr = False
- i = len(s2) - 1
- while i >= 0:
- ch = s2[i]
- if in_str:
- if ch == '\\': i -= 2; continue
- if ch == '"': in_str = False; i -= 1; continue
- i -= 1; continue
- if in_chr:
- if ch == '\\': i -= 2; continue
- if ch == "'": in_chr = False; i -= 1; continue
- i -= 1; continue
- if ch == '"': in_str = True; i -= 1; continue
- if ch == "'": in_chr = True; i -= 1; continue
- if ch == ')': depth += 1; i -= 1; continue
- if ch == '(':
- depth -= 1
- if depth == 0: return s2[:i].rstrip()
- i -= 1; continue
- i -= 1
- return s
-
- def _find_method_return_type_in_file(source_text: str, method_name: str) -> str | None:
- sig = re.compile(
+ """
+ Same logic, but uses shared helpers for dot-splitting and call-paren stripping.
+ """
+ @lru_cache(maxsize=1024)
+ def _file_sig_re(method_name: str) -> re.Pattern:
+ return re.compile(
rf"""(?P<ret>[A-Za-z_.$?][\w$.<>\[\]?]*?)\s+{re.escape(method_name)}\s*\(
- """, re.VERBOSE)
- m = sig.search(source_text)
+ """,
+ re.VERBOSE,
+ )
+
+ def _find_method_return_type_in_file(source_text: str, method_name: str) -> str | None:
+ m = _file_sig_re(method_name).search(source_text or "")
if not m:
return None
- ret = m.group('ret').strip()
- ret = re.sub(r'<[^<>]*>', '', ret)
- ret = re.sub(r'\s*(\[\])+$', '', ret)
+ ret = m.group("ret").strip()
+ ret = re.sub(r"<[^<>]*>", "", ret)
+ ret = _ARRAY_SUFFIX_RE.sub("", ret).strip()
return ret or None
- expr = initializer.strip().rstrip(';')
- chain = _split_top_level_dots(expr)
+ expr = (initializer or "").strip().rstrip(';')
+ chain = self._split_top_level_dots(expr)
if not chain:
return None
- # ----- CASE 1: starts with a variable identifier (old behavior) -----
start = chain[0].strip()
ident_only = re.match(r'^[A-Za-z_$][\w$]*$', start)
if ident_only and (len(chain) == 1 or not start.endswith(')')):
@@ -3096,13 +3573,11 @@ def _find_method_return_type_in_file(source_text: str, method_name: str) -> str
if not start_type:
return None
cur_type = re.sub(r'<[^<>]*>', '', start_type)
- cur_type = re.sub(r'\s*(\[\])+$', '', cur_type)
-
+ cur_type = _ARRAY_SUFFIX_RE.sub("", cur_type).strip()
else:
- # Examine only the "head" before '(' to decide qualification
- head = _strip_call_parens(start).strip()
+ head = self._strip_call_parens(start).strip()
- # ----- CASE 2: unqualified method call in same file: foo(a,b) -----
+ # CASE 2: unqualified method call in same file: foo(a,b)
if start.endswith(')') and '.' not in head:
mname = re.match(r'^([A-Za-z_$][\w$]*)$', head)
if not mname:
@@ -3114,32 +3589,29 @@ def _find_method_return_type_in_file(source_text: str, method_name: str) -> str
if not ret:
return None
cur_type = re.sub(r'<[^<>]*>', '', ret)
- cur_type = re.sub(r'\s*(\[\])+$', '', cur_type)
+ cur_type = _ARRAY_SUFFIX_RE.sub("", cur_type).strip()
- # ----- CASE 3: static-style or qualified head: ClassName.staticMethod(...) or pkg.Class.m(...)
+ # CASE 3: static-style or qualified head: ClassName.staticMethod(...) or pkg.Class.m(...)
elif start.endswith(')') and '.' in head:
- # Split head into type-like left and method right
left, _, right = head.rpartition('.')
- # Left should be a type token (simple or qualified). We don't verify imports here.
base_type = re.sub(r'<[^<>]*>', '', left)
- base_type = re.sub(r'\s*(\[\])+$', '', base_type)
+ base_type = _ARRAY_SUFFIX_RE.sub("", base_type).strip()
method_name = right
if not re.match(r'^[A-Za-z_$][\w$]*$', method_name):
return None
- # Return type of static method on 'base_type'
ret = self._find_method_return_type_in_type_docs(base_type, method_name, type_documents)
if not ret:
return None
cur_type = re.sub(r'<[^<>]*>', '', ret)
- cur_type = re.sub(r'\s*(\[\])+$', '', cur_type)
+ cur_type = _ARRAY_SUFFIX_RE.sub("", cur_type).strip()
else:
return None
- # ----- Walk remaining segments; method calls change the type -----
+ # Walk remaining segments; method calls change the type
for seg in chain[1:]:
- seg_core = _strip_call_parens(seg).strip()
- if not seg.endswith(')'): # field access — keep type
+ seg_core = self._strip_call_parens(seg).strip()
+ if not seg.endswith(')'):
continue
mname = re.match(r'^([A-Za-z_$][\w$]*)\s*$', seg_core)
if not mname:
@@ -3149,9 +3621,10 @@ def _find_method_return_type_in_file(source_text: str, method_name: str) -> str
if not ret:
return None
cur_type = re.sub(r'<[^<>]*>', '', ret)
- cur_type = re.sub(r'\s*(\[\])+$', '', cur_type)
+ cur_type = _ARRAY_SUFFIX_RE.sub("", cur_type).strip()
return cur_type or None
def get_package_name(self, function: Document, package_name: str) -> str:
- return package_name if extract_jar_name(function.metadata['source']) in package_name else ''
\ No newline at end of file
+ jar_name = extract_jar_name(function.metadata['source'])
+ return package_name if jar_name and jar_name in package_name else ''
\ No newline at end of file
diff --git a/src/exploit_iq_commons/utils/java_chain_of_calls_retriever.py b/src/exploit_iq_commons/utils/java_chain_of_calls_retriever.py
index 33c43afc..8a6cadad 100644
--- a/src/exploit_iq_commons/utils/java_chain_of_calls_retriever.py
+++ b/src/exploit_iq_commons/utils/java_chain_of_calls_retriever.py
@@ -14,8 +14,9 @@
)
from exploit_iq_commons.logging.loggers_factory import LoggingFactory
-from exploit_iq_commons.utils.java_utils import convert_from_maven_artifact, extract_jar_name, is_maven_gav, extract_method_name_with_params, \
- create_inheritance_map, get_target_class_names, dummy_package_name
+from exploit_iq_commons.utils.java_utils import convert_from_maven_artifact, extract_jar_name, is_maven_gav, \
+ extract_method_name_with_params, \
+ create_inheritance_map, get_target_class_names, dummy_package_name, extract_fqcn
from exploit_iq_commons.data_models.input import SourceDocumentsInfo
logger = LoggingFactory.get_agent_logger(f"morpheus.{__name__}")
@@ -81,7 +82,7 @@ def __init__(self, documents: List[Document],
logger.debug("Chain of Calls Retriever - populating functions documents")
allowed_files_extensions = self.language_parser.supported_files_extensions()
- # filter out unsupported files extensions.
+ # filter out unsupported files extensions and test files
documents = [doc for doc in documents
if any([ext for ext in allowed_files_extensions if str(doc.metadata['source'])
.endswith(ext)])]
@@ -175,18 +176,21 @@ def __find_caller_function(self, document_function: Document, function_package:
method_exclusions = func_pack_from_tree[METHOD_EXCLUSIONS_INDEX]
target_class_names: frozenset[str]
- class_name = self.language_parser.get_class_name_from_class_function(document_function)
- key = (class_name, document_function.metadata['source'])
+
+ fqcn = self.language_parser.get_class_name_from_class_function(document_function)
if "dummy" not in function_file_name:
+ key = (fqcn, function_file_name)
target_class_names = get_target_class_names(self.type_inheritance[key])
else:
- target_class_names = frozenset([class_name])
- target_type_doc = Document(page_content="public class " + class_name + "{}",
+ fqcn_no_dummy = fqcn.replace(dummy_package_name, "")
+ target_class_names = frozenset([fqcn_no_dummy])
+ target_type_doc = Document(page_content="public class " + fqcn.rpartition('.')[2] + "{}",
metadata={"source": self.language_parser.dir_name_for_3rd_party_packages() +
- "/" + function_package.partition(':')[-1] +
- "/" + class_name + ".java",
+ "/" + fqcn_no_dummy.replace(".", "/") +
+ ".java",
"ecosystem": self.ecosystem})
documents_of_types.append(target_type_doc)
+ document_function = target_type_doc
# Search for caller functions only at parents according to dependency tree.
for package in direct_parents[last_visited_package_index:]:
@@ -410,15 +414,15 @@ def get_relevant_documents(self, query: str) -> tuple[List[Document], bool]:
return matching_documents, self.found_path
def extract_from_query(self, query: str) -> tuple[str, str, str]:
- (package_name, function) = tuple( query.splitlines()[0].strip('"\'').replace("#", ".").split(","))
+ (package_name, function) = tuple( query.splitlines()[0].strip('"\'\u2018\u2019\u201c\u201d').replace("#", ".").split(","))
class_name = function.rpartition('.')[0]
- method_name = function.rpartition('.')[2]
+ method_name = re.sub(r'\(.*\)$', '', function.rpartition('.')[2])
if not class_name and self.is_java_fqcn(package_name):
class_name = package_name
- package_name, class_name = self.infer_class_name_and_package_name(method_name, class_name)
+ package_name, class_name = self.infer_class_name_and_package_name(method_name, class_name, package_name)
return class_name, method_name, package_name
@@ -491,10 +495,10 @@ def __find_initial_function(self, class_name: str, method_name: str, package_nam
jar_name = convert_from_maven_artifact(package_name)
relevant_docs = [doc for doc in documents
- if (self.language_parser.dir_name_for_3rd_party_packages() in doc.metadata.get('source')
- and jar_name in doc.metadata.get('source') or
+ if ((self.language_parser.dir_name_for_3rd_party_packages() in doc.metadata.get('source')
+ and jar_name in doc.metadata.get('source')) or
not self.language_parser.dir_name_for_3rd_party_packages() in doc.metadata.get('source')) and
- self.language_parser.get_class_name_from_class_function(doc) == class_name.rpartition('.')[2] and
+ self.language_parser.get_class_name_from_class_function(doc) == class_name and
self.language_parser.get_function_name(doc) == method_name]
package_exclusions = self.tree_dict.get(package_name)[EXCLUSIONS_INDEX]
# TODO handle method overloading
@@ -675,55 +679,27 @@ def is_java_fqcn(self, s: str) -> bool:
return _FQCN_STRICT_RE.match(s) is not None
- def infer_class_name_and_package_name(self, method_name: str, class_name: str) -> (str, str):
- simple_class_name = class_name
- parts = class_name.split(".")
- # if fqcn, take the simple class name
- if len(parts) > 1:
- simple_class_name = parts[-1]
+ def infer_class_name_and_package_name(self, method_name: str, class_name: str, package_name: str) -> (str, str):
+ maven_gav = is_maven_gav(package_name)
for doc in self.documents_of_functions:
if method_name in doc.page_content and method_name == self.language_parser.get_function_name(doc):
- if self.language_parser.get_class_name_from_class_function(doc) == simple_class_name:
- source = doc.metadata['source']
+ doc_fqcn = self.language_parser.get_class_name_from_class_function(doc)
+ if class_name in doc_fqcn:
+ source = doc.metadata['source']
artifact_name = self.extract_maven_artifact(source)
- for val in self.tree_dict:
- if artifact_name in val:
- return val, self.extract_fqcn(source)
-
- return "", class_name
-
- def extract_fqcn(self, text: str) -> str:
- """
- Convert a source file path to an FQCN.
- Works both when a '-sources/' prefix exists and when it doesn't.
-
- Examples:
- 'dependencies-sources/hibernate-core-6.6.13.Final-sources/org/hibernate/type/descriptor/java/ArrayJavaType.java'
- -> 'org.hibernate.type.descriptor.java.ArrayJavaType'
- 'org/hibernate/type/descriptor/java/ArrayJavaType.java'
- -> 'org.hibernate.type.descriptor.java.ArrayJavaType'
- """
-
- p = text.replace("\\", "/")
-
- # If there's a '-sources/' prefix, drop everything up to and including the last occurrence.
- marker = "-sources/"
- cut = p.rfind(marker)
- tail = p[cut + len(marker):] if cut != -1 else p
-
- # Strip leading slash if any
- if tail.startswith("/"):
- tail = tail[1:]
-
- # Drop the .java suffix (case-sensitive per your example; change to lower() if needed)
- if tail.endswith(".java"):
- tail = tail[:-5]
+ # If the library is provided make sure the doc belongs to library
+ if maven_gav and artifact_name in package_name:
+ return package_name, doc_fqcn
+ # If no library is provided look for the function
+ elif not maven_gav:
+ for val in self.tree_dict:
+ if artifact_name in val:
+ return val, doc_fqcn
- # Convert path separators to dots to form the FQCN
- return tail.replace("/", ".")
+ return package_name, class_name
def extract_maven_artifact(self, path: str) -> str:
"""
diff --git a/src/exploit_iq_commons/utils/java_segmenters_with_methods.py b/src/exploit_iq_commons/utils/java_segmenters_with_methods.py
index 6b54c00d..428114b4 100644
--- a/src/exploit_iq_commons/utils/java_segmenters_with_methods.py
+++ b/src/exploit_iq_commons/utils/java_segmenters_with_methods.py
@@ -1,26 +1,366 @@
+from __future__ import annotations
+
import bisect
import os
+import re
-from typing import List, Tuple, Optional, Dict, Set
+from typing import List, Tuple, Optional, Dict, FrozenSet
from langchain_community.document_loaders.parsers.language.java import JavaSegmenter
+_TEST_METHOD_ANNOTATIONS: FrozenSet[str] = frozenset({
+ # JUnit 4/5
+ "Test",
+ "ParameterizedTest",
+ "RepeatedTest",
+ "TestFactory",
+ "TestTemplate",
+ # TestNG
+ "org.testng.annotations.Test", # handled by simple-name "Test" too
+})
+
+_TEST_LIFECYCLE_ANNOTATIONS: FrozenSet[str] = frozenset({
+ # JUnit 4
+ "Before",
+ "After",
+ "BeforeClass",
+ "AfterClass",
+ # JUnit 5
+ "BeforeEach",
+ "AfterEach",
+ "BeforeAll",
+ "AfterAll",
+ # TestNG
+ "BeforeMethod",
+ "AfterMethod",
+ "BeforeClass",
+ "AfterClass",
+ "BeforeSuite",
+ "AfterSuite",
+ "BeforeTest",
+ "AfterTest",
+ "BeforeGroups",
+ "AfterGroups",
+})
+
+# Class-level "this is a test class" markers (Quarkus + common Spring test slices)
+_TEST_CLASS_ANNOTATIONS: FrozenSet[str] = frozenset({
+ "QuarkusTest",
+
+ # Spring / Spring Boot test slices (common ones)
+ "SpringBootTest",
+ "WebMvcTest",
+ "WebFluxTest",
+ "DataJpaTest",
+ "DataJdbcTest",
+ "DataMongoTest",
+ "JdbcTest",
+ "JooqTest",
+ "JsonTest",
+ "RestClientTest",
+ "GraphQlTest",
+ "ContextConfiguration",
+ "SpringJUnitConfig",
+ "SpringJUnitWebConfig",
+ "AutoConfigureMockMvc",
+ "AutoConfigureWebTestClient",
+
+ # Older Spring JUnit 4 runners/config
+ "RunWith", # handled specially (needs SpringRunner)
+ "ExtendWith", # handled specially (needs SpringExtension)
+})
+
+# Scan only these chars; everything else can be skipped in big jumps.
+# (We deliberately include '-' because we only care about "->".)
+_LAMBDA_SCAN_SPECIAL_RE = re.compile(r'[/"\'-]')
+
+_pkg_imp_re = re.compile(r"(?m)^[ \t]*(?:package|import)\s+[^;\n]+;\s*\n")
class JavaSegmenterWithMethods(JavaSegmenter):
def __init__(self, code: str):
super().__init__(code)
+ # ------------------------------ test filtering --------------------------------
+
+ def _looks_like_template_source(self, src: str) -> bool:
+ """
+ Return True if the source contains template-language markers in *code*.
+
+ This scanner ignores:
+ - Java // line comments
+ - Java /* ... */ block/Javadoc comments
+ - Java string literals "..."
+ - Java char literals '...'
+ - Java text blocks \"\"\" ... \"\"\" (Java 15+)
+
+ Therefore, template markers that occur *only* inside strings/comments/text blocks
+ will NOT cause the file to be skipped.
+ """
+ # Cheap prefilter
+ if "<" not in src and "$" not in src:
+ return False
+
+ n = len(src)
+ i = 0
+
+ while i < n:
+ ch = src[i]
+
+ # ---- comments ----
+ if ch == '/' and i + 1 < n:
+ nxt = src[i + 1]
+ if nxt == '/':
+ # line comment
+ i = src.find('\n', i + 2)
+ if i == -1:
+ return False
+ i += 1
+ continue
+ if nxt == '*':
+ # block/Javadoc comment
+ end = src.find('*/', i + 2)
+ if end == -1:
+ return False
+ i = end + 2
+ continue
+
+ # ---- text blocks / strings / chars ----
+ if src.startswith('"""', i):
+ # text block
+ i += 3
+ while i < n:
+ j = src.find('"""', i)
+ if j == -1:
+ return False
+ # if escaped \"\"\", keep scanning
+ k = j - 1
+ backslashes = 0
+ while k >= 0 and src[k] == '\\':
+ backslashes += 1
+ k -= 1
+ if (backslashes % 2) == 1:
+ i = j + 1
+ continue
+ i = j + 3
+ break
+ continue
+
+ if ch == '"':
+ # regular string
+ i += 1
+ while i < n:
+ if src[i] == '\\':
+ i += 2
+ continue
+ if src[i] == '"':
+ i += 1
+ break
+ i += 1
+ continue
+
+ if ch == "'":
+ # char literal
+ i += 1
+ while i < n:
+ if src[i] == '\\':
+ i += 2
+ continue
+ if src[i] == "'":
+ i += 1
+ break
+ i += 1
+ continue
+
+ # ---- marker detection in code only ----
+ # Check the handful of marker prefixes cheaply.
+ if ch == '$' and i + 1 < n and src[i + 1] == '{':
+ return True
+ if ch == '<' and i + 1 < n:
+ c1 = src[i + 1]
+ if c1 in ('#', '@'):
+ return True
+ if c1 == '/' and i + 2 < n and src[i + 2] == '#':
+ return True
+
+ i += 1
+
+ return False
+
+ def _read_qualified_ident_after_at(self, src: str, at_idx: int, end: int) -> Tuple[int, str, str]:
+ """
+ Parse an annotation name starting at '@' (src[at_idx] == '@') within [at_idx, end).
+
+ Returns
+ -------
+ (next_idx, full_name, simple_name)
+
+ - full_name can include dots: org.junit.jupiter.api.Test
+ - simple_name is the last segment: Test
+ """
+ i = at_idx + 1
+ n = min(end, len(src))
+
+ # Special-case annotation type declaration: '@interface' is not a "use" annotation.
+ if i + 8 <= n and src.startswith("interface", i):
+ return i + 9, "interface", "interface"
+
+ # Read first identifier segment
+ if i >= n or not (_is_ident_start(src[i]) or src[i].isalpha()):
+ return at_idx + 1, "", ""
+
+ start = i
+ i += 1
+ while i < n and _is_ident_part(src[i]):
+ i += 1
+
+ # Read optional ".segment" repeats
+ while i < n and src[i] == '.':
+ dot = i
+ j = dot + 1
+ if j < n and (_is_ident_start(src[j]) or src[j].isalpha()):
+ i = j + 1
+ while i < n and _is_ident_part(src[i]):
+ i += 1
+ else:
+ break
+
+ full_name = src[start:i]
+ simple_name = full_name.rsplit(".", 1)[-1]
+ return i, full_name, simple_name
+
+ def _source_is_test_source(self, src: str) -> bool:
+ """Heuristic test-source detector.
+
+ Returns True iff `src` *appears* to be a test class/source.
+
+ Detects:
+ - JUnit 4/5 + TestNG method/lifecycle annotations (e.g. @Test, @BeforeEach)
+ - JUnit3-style `extends TestCase`
+ - Quarkus / Spring / Spring Boot test classes via `_TEST_CLASS_ANNOTATIONS`
+
+ Notes:
+ - Ignores tokens inside comments and string/text blocks.
+ - Designed to be fast: a cheap substring prefilter avoids scanning most sources.
+ """
+ # Fast prefilter:
+ # - If there is no '@', the only supported positive is JUnit3-style `extends ... TestCase`.
+ # - If '@' exists, avoid scanning typical non-test sources by requiring at least one
+ # test-ish keyword to appear somewhere in the raw text.
+ if "@" not in src:
+ if "TestCase" not in src or "extends" not in src:
+ return False
+ else:
+ # Important: test-class annotations are often present even if no @Test is.
+ # `Test` is a broad but capitalized token that strongly correlates with test annotations.
+ if ("Test" not in src and
+ "RunWith" not in src and
+ "ExtendWith" not in src and
+ "ContextConfiguration" not in src and
+ "SpringJUnit" not in src and
+ "AutoConfigureMockMvc" not in src and
+ "AutoConfigureWebTestClient" not in src and
+ "org.junit" not in src and
+ "org.testng" not in src):
+ return False
+
+ n = len(src)
+ i = 0
+ saw_extends = False
+
+ while i < n:
+ i = _skip_ws_comments(src, i)
+ if i >= n:
+ break
+
+ ch = src[i]
+
+ # Skip any string/char/text-block literal quickly.
+ if ch == '"' or ch == "'" or (ch == 'r' and i + 1 < n and src[i + 1] == '"'):
+ i = _skip_java_string_like(src, i)
+ continue
+
+ if ch == '@':
+ # Parse annotation name (qualified or simple).
+ ni, _full, simple = self._read_qualified_ident_after_at(src, i, n)
+
+ # Avoid misclassifying annotation type declarations: `public @interface Foo { ... }`
+ if simple == "interface":
+ i = ni
+ continue
+
+ # JUnit/TestNG method-level / lifecycle annotations.
+ if simple in _TEST_METHOD_ANNOTATIONS or simple in _TEST_LIFECYCLE_ANNOTATIONS:
+ return True
+
+ # Quarkus / Spring / Spring Boot class-level test annotations.
+ if simple in _TEST_CLASS_ANNOTATIONS:
+ # Only treat generic @ExtendWith/@RunWith as a "Spring test" if they bind
+ # the Spring extension/runner; otherwise they are too generic.
+ if simple == "ExtendWith":
+ j = _skip_ws_comments(src, ni)
+ if j < n and src[j] == '(':
+ close = _match_balanced_parens(src, j)
+ if close != -1 and "SpringExtension" in src[j:close]:
+ return True
+ i = ni
+ continue
+
+ if simple == "RunWith":
+ j = _skip_ws_comments(src, ni)
+ if j < n and src[j] == '(':
+ close = _match_balanced_parens(src, j)
+ if close != -1 and "SpringRunner" in src[j:close]:
+ return True
+ i = ni
+ continue
+
+ return True
+
+ # Skip annotation argument list (if any) to avoid rescanning big nested blocks.
+ i = _skip_ws_comments(src, ni)
+ if i < n and src[i] == '(':
+ close = _match_balanced_parens(src, i)
+ i = (close + 1) if close != -1 else (i + 1)
+ continue
+
+ # Detect JUnit3 style `extends ... TestCase` (ignoring comments/strings).
+ if _is_ident_start(ch):
+ j = i + 1
+ while j < n and _is_ident_part(src[j]):
+ j += 1
+ ident = src[i:j]
+
+ if ident == "extends":
+ saw_extends = True
+ elif saw_extends:
+ # Allow qualified form: `extends junit.framework.TestCase`
+ if ident.endswith("TestCase") and ident.rsplit(".", 1)[-1] == "TestCase":
+ return True
+ saw_extends = False
+
+ i = j
+ continue
+
+ i += 1
+
+ return False
+
def extract_functions_classes(self) -> List[str]:
"""
This method extracts all the golang methods and anonymous functions from a source code, and appends them to
- the list of all golang regular functions calculated in parent class' method.
- :
- :return: a list of all golang functions/methods/anonymous functions ( each one with a complete implementation,
+ the list of all java regular functions calculated in parent class' method.
+ :return: a list of all java functions/methods/anonymous functions ( each one with a complete implementation,
signature + body)
"""
+ if self._looks_like_template_source(self.code):
+ return []
+
+ if self._source_is_test_source(self.code):
+ return []
+
function_classes = super().extract_functions_classes()
- methods = extract_methods(self.code) # TODO Add constructors to the method list
+ methods = extract_methods(self.code)
inner_classes = extract_inner_classes(self.code) # TODO Add anonymous inner classes support
function_classes.extend(methods)
function_classes.extend(inner_classes)
@@ -876,108 +1216,32 @@ def _match_balanced_braces(source: str, open_brace_index: int) -> int:
# ------------------------------ backward helpers ------------------------------
-def _prev_word(src: str, pos: int) -> str:
- j = pos - 1
- while j >= 0 and src[j] in ' \t\r\n\f':
- j -= 1
- if j < 0:
- return ""
- end = j + 1
- while j >= 0 and _is_ident_part(src[j]):
- j -= 1
- return src[j+1:end]
-
-def _skip_leading_annotations(src: str, i: int) -> int:
- """
- From position i (roughly at the beginning of a signature), skip any number of
- annotations and comments so that method-level annotations do NOT appear in the
- returned snippet. Returns first non-annotation token index.
- """
- n = len(src)
- i = _skip_ws_comments(src, i)
- while i < n:
- if src[i] == '@':
- i += 1
- while i < n and (src[i].isalnum() or src[i] in ('_', '$', '.')):
- i += 1
- i = _skip_ws_comments(src, i)
- if i < n and src[i] == '(':
- i = _match_balanced_parens(src, i) + 1
- i = _skip_ws_comments(src, i)
- continue
- j = _skip_ws_comments(src, i)
- if j != i:
- i = j
- continue
- break
- return i
-
-def _back_ident(src: str, i: int) -> Tuple[int, Optional[str]]:
- """
- From position i (typically the '(' of a parameter list), scan left to find the
- simple identifier just before it. Returns (start_index, name) or (i, None).
+def _strip_java_noise(snippet: str) -> str:
"""
- j = i - 1
- # skip ws
- while j >= 0 and src[j] in ' \t\r\n\f':
- j -= 1
- if j < 0:
- return i, None
- end = j + 1
- while j >= 0 and _is_ident_part(src[j]):
- j -= 1
- start = j + 1
- name = src[start:end]
- if name and _is_ident_start(name[0]):
- return start, name
- return i, None
-
-def _find_left_boundary(src: str, before: int) -> int:
+ Remove comments, javadocs, annotations, and (defensively) package/import lines.
+ Designed to be fast: cheap substring checks first, heavy work only if needed.
"""
- Return index just after the nearest real boundary (';', '{', '}') occurring before `before`,
- ignoring comments/strings and jumping over balanced () for speed.
- IMPORTANT: Treat '}', '{', ';' as immediate boundaries BEFORE any jump, so we don't leap
- across the previous method body and accidentally glue multiple methods together.
- """
- if _DELIM_SRC_ID != id(src):
- # simple forward fallback
- i = 0
- last = 0
- before = max(0, min(before, len(src)))
- while i < before:
- j = _skip_ws_comments(src, i)
- if j != i:
- i = j; continue
- if i < before and (src.startswith('"""', i) or src[i] in ('"', "'")):
- i = _skip_java_string_like(src, i); continue
- if src[i] in (';', '{', '}'):
- last = i + 1
- i += 1
- return last
-
- # fast backward using indexes
- i = min(before - 1, len(src) - 1)
- while i >= 0:
- # hop out of ignored span
- span = _ignored_span_at(i)
- if span:
- i = span[0] - 1
- continue
+ if not snippet:
+ return snippet
- ch = src[i]
+ # Fast path: almost all snippets do NOT need stripping.
+ if (
+ "/*" not in snippet
+ and "//" not in snippet
+ and "@" not in snippet
+ and "package" not in snippet
+ and "import" not in snippet
+ ):
+ return snippet
- # FIRST: boundary check
- if ch in (';', '{', '}'):
- return i + 1
+ # Reuse existing (previously unused) stripper: removes comments + annotations safely.
+ snippet = _strip_body_comments_and_annotations(snippet, 0, len(snippet))
- # jump across groups for speed (do NOT jump across '}' – handled above)
- if ch == ')':
- o = _PAR_C2O.get(i) if _DELIM_SRC_ID == id(src) else None
- i = (o - 1) if o is not None else (i - 1)
- continue
+ # Remove package/import lines if present (defensive; inner-class extraction used to prepend imports).
+ if "package" in snippet or "import" in snippet:
+ snippet = _pkg_imp_re.sub("", snippet)
- i -= 1
- return 0
+ return snippet
def _match_paren_backwards(source: str, close_paren_index: int) -> int:
"""
@@ -1103,75 +1367,6 @@ def __init__(self, name: str, start: int, end: int) -> None:
self.start = start
self.end = end
-def _find_all_named_types(src: str) -> List[_TypeRegion]:
- """
- Find all class/interface/enum/record bodies and names, including nested types.
- """
- n = len(src)
- i = 0
- types: List[_TypeRegion] = []
- while i < n:
- i = _skip_ws_comments(src, i)
- if i >= n: break
-
- # skip literals
- if src.startswith('"""', i) or src[i] in ('"', "'"):
- i = _skip_java_string_like(src, i); continue
-
- # token
- if _is_ident_start(src[i]) or src[i] == '@':
- tok_start = i + (1 if src[i] == '@' else 0)
- j = tok_start
- while j < n and _is_ident_part(src[j]): j += 1
- tok = src[tok_start:j]
-
- if tok in ('class', 'interface', 'enum', 'record'):
- # name
- k = _skip_ws_comments(src, j)
- name_start = k
- while k < n and _is_ident_part(src[k]): k += 1
- if k == name_start: # '.class' or broken
- i = j; continue
- name = src[name_start:k]
-
- # type params / record header
- k = _skip_ws_comments(src, k)
- if k < n and src[k] == '<':
- k = _match_balanced_angles(src, k) + 1
- k = _skip_ws_comments(src, k)
- if tok == 'record' and k < n and src[k] == '(':
- k = _match_balanced_parens(src, k) + 1
- k = _skip_ws_comments(src, k)
-
- # to opening '{'
- while k < n and src[k] != '{':
- if src.startswith('"""', k) or src[k] in ('"', "'"):
- k = _skip_java_string_like(src, k)
- else:
- k += 1
- k = _skip_ws_comments(src, k)
-
- if k < n and src[k] == '{':
- end = _match_balanced_braces(src, k)
- if name:
- types.append(_TypeRegion(name=name, start=k, end=end))
- i = k + 1
- continue
-
- i = j
- continue
-
- i += 1
-
- return types
-
-def _enclosing_type_name(types: List[_TypeRegion], pos: int) -> Optional[str]:
- best: Optional[_TypeRegion] = None
- for t in types:
- if t.start <= pos <= t.end and (best is None or t.start >= best.start):
- best = t
- return best.name if best else None
-
# --------------------------------- lambdas ------------------------------------
def _capture_lambda_at_arrow(src: str, arrow: int) -> Optional[Tuple[int, int]]:
@@ -1242,403 +1437,1204 @@ def _capture_lambda_at_arrow(src: str, arrow: int) -> Optional[Tuple[int, int]]:
end = k
return (param_start, end)
-def _extract_lambdas_in_range(src: str, start: int, end: int) -> List[str]:
+def _extract_lambdas_in_range(src: str, start: int, end: int) -> List[str]:
+ """
+ Extract Java lambda expressions within [start, end).
+
+ Optimized for huge files:
+ - Fast early-out if there's no '->' at all in the range.
+ - Uses a regex "find next interesting char" jump-table so we don't
+ call helpers per-character.
+
+ Semantics are intended to match the previous implementation:
+ - Skip // and /*...*/ comments
+ - Skip strings/chars/text blocks via _skip_java_string_like()
+ - Capture lambdas via _capture_lambda_at_arrow()
+ - Strip comments in the captured lambda via _strip_comments_keep_strings_range()
+ """
out: List[str] = []
- i = start
n = len(src)
- end = min(end, n)
- while i < end:
- j = _skip_ws_comments(src, i)
- if j != i:
- i = j; continue
- if i < end and (src.startswith('"""', i) or src[i] in ('"', "'")):
- i = _skip_java_string_like(src, i); continue
- if i + 1 < end and src[i] == '-' and src[i + 1] == '>':
- span = _capture_lambda_at_arrow(src, i)
- if span:
- s, e = span
- if s >= start and e <= end:
- clean = _strip_comments_keep_strings_range(src, s, e)
- out.append(clean)
- i = e
+ if start < 0:
+ start = 0
+ if end > n:
+ end = n
+ if start >= end:
+ return out
+
+ # Cheap win for big files with no lambdas at all (like many autogenerated sources):
+ if src.find("->", start, end) == -1:
+ return out
+
+ i = start
+ while True:
+ m = _LAMBDA_SCAN_SPECIAL_RE.search(src, i, end)
+ if not m:
+ break
+
+ i = m.start()
+ ch = src[i]
+
+ # Comments
+ if ch == '/':
+ if i + 1 < end:
+ nxt = src[i + 1]
+ # line comment
+ if nxt == '/':
+ nl = src.find('\n', i + 2, end)
+ i = end if nl == -1 else nl + 1
+ continue
+ # block / javadoc comment
+ if nxt == '*':
+ j = src.find('*/', i + 2, end)
+ i = end if j == -1 else j + 2
+ continue
+ i += 1
+ continue
+
+ # Strings / chars / text blocks
+ if ch in ('"', "'"):
+ i = _skip_java_string_like(src, i)
+ continue
+
+ # Potential lambda arrow
+ if ch == '-':
+ if i + 1 < end and src[i + 1] == '>':
+ span = _capture_lambda_at_arrow(src, i)
+ if span:
+ s, e = span
+ if s >= start and e <= end:
+ clean = src[s:e]
+ clean = _strip_java_noise(clean)
+ out.append(clean)
+ i = e
+ continue
+ i += 2
continue
- i += 2
+
+ i += 1
continue
- i += 1
- return out
-# ------------------ small helpers for method detection ------------------------
+ # Fallback (should be unreachable due to regex character class)
+ i += 1
-def _header_closes_at_brace(src: str, close_p: int, brace_idx: int) -> bool:
- """Return True if src[close_p] == ')' and after optional ws/comments + throws we land exactly on '{'."""
- k = _skip_ws_comments(src, close_p + 1)
- k = _skip_throws_clause(src, k)
- k = _skip_ws_comments(src, k)
- return k == brace_idx
+ return out
# ------------------------------- core extractor -------------------------------
def _extract_methods_anywhere(
src: str,
start: int = 0,
- end: Optional[int] = None,
- types: Optional[List[_TypeRegion]] = None,
-) -> List[str]:
+ end: int | None = None,
+ types: list | None = None,
+) -> list[str]:
"""
Extract Java method definitions (with bodies) anywhere within [start, end),
- INCLUDING constructors. Anonymous/inner/local class bodies are recursively scanned.
- Method-level annotations preceding a signature are excluded from the returned slice.
- Lambdas are NOT returned here (they are gathered in a separate pass).
-
- Performance:
- - Single forward scan with O(1) delimiter lookups when _set_delim_index() cache
- is active.
- - Avoids quadratic behavior by skipping ignored spans and using balanced matchers.
+ INCLUDING constructors. Anonymous/inner/local class bodies are scanned.
+
+ Key behavior:
+ - For named inner types, emitted comment includes ALL nesting levels
+ EXCLUDING the top-level declaring type:
+ e.g. "Outer.Inner1.Inner2" -> "Inner1.Inner2",
+ "CreateMultipartUploadRequest.BuilderImpl" -> "BuilderImpl".
+
+ Annotation behavior:
+ - Lines that contain ONLY annotations (and whitespace/comments) immediately
+ preceding a member are excluded from the returned slice.
+
+ Enum behavior (critical fix):
+ - Do NOT treat enum constant initializers as class members.
+ Method parsing inside a named enum body is disabled until the FIRST ';'
+ at that enum body's brace depth (the end of the enum-constant section).
+
+ Notes on performance:
+ - Single forward scan.
+ - Balanced matchers are used only for candidates at class-member level.
+ - No regex usage in hot paths.
"""
- if end is None:
- end = len(src)
- if types is None:
- types = _find_all_named_types(src)
-
- # Enclosing named-type lookup (innermost)
- types_sorted = sorted(types, key=lambda t: t.start)
- type_starts = [t.start for t in types_sorted]
-
- def _enclosing_region(pos: int) -> Optional[_TypeRegion]:
- """
- Return the innermost named type region containing pos.
- IMPORTANT: keep scanning left until we either find an enclosing region
- or run out of candidates. Do NOT early-return None if the nearest
- left type ended before pos; an outer type may still enclose pos.
- """
- idx = bisect.bisect_right(type_starts, pos) - 1
- while idx >= 0:
- t = types_sorted[idx]
- if t.start <= pos <= t.end:
- return t
- # If this candidate doesn't enclose pos, step left and keep looking.
- idx -= 1
- return None
- # Quick enum-constants section end cache (keyed by enum body '{' position)
- enum_term_cache: Dict[int, int] = {}
-
- def _region_kind(reg: _TypeRegion) -> str:
- if not reg:
- return ""
- i2 = _find_left_boundary(src, reg.start)
- while i2 < reg.start:
- j2 = _skip_ws_comments(src, i2); i2 = j2
- if i2 >= reg.start: break
- if src.startswith('"""', i2) or src[i2] in ('"', "'"):
- i2 = _skip_java_string_like(src, i2); continue
- if _is_ident_start(src[i2]) or src[i2] == '@':
- tok_start = i2 + (1 if src[i2] == '@' else 0)
- j2 = tok_start
- while j2 < reg.start and _is_ident_part(src[j2]): j2 += 1
- tok = src[tok_start:j2]
- if tok in ('class', 'interface', 'enum', 'record'):
- return tok
- i2 = j2; continue
- i2 += 1
- return ""
+ # -------------------------
+ # Small internal structs
+ # -------------------------
+ class _TypeRegion:
+ __slots__ = ("kind", "name", "start", "end", "body_start", "body_end", "parent", "qual_name")
+
+ def __init__(
+ self,
+ kind: str,
+ name: str,
+ start: int,
+ end: int,
+ body_start: int,
+ body_end: int,
+ parent: int,
+ qual_name: str,
+ ) -> None:
+ self.kind = kind
+ self.name = name
+ self.start = start
+ self.end = end
+ self.body_start = body_start
+ self.body_end = body_end
+ self.parent = parent
+ self.qual_name = qual_name
+
+ class _ClassBodyFrame:
+ __slots__ = ("named_type_idx", "body_depth", "member_start",
+ "is_enum", "enum_constants_done")
+
+ def __init__(self, named_type_idx: int, body_depth: int, member_start: int, is_enum: bool) -> None:
+ self.named_type_idx = named_type_idx # >=0 for named type, -1 for anonymous class body
+ self.body_depth = body_depth # brace_depth inside this body
+ self.member_start = member_start # where next member starts
+
+ self.is_enum = is_enum
+ # For enums: constants come first. Members start only after the first ';' at body depth.
+ self.enum_constants_done = (not is_enum)
+
+ # -------------------------
+ # Low-level helpers (all local)
+ # -------------------------
+ def _is_id_start(c: str) -> bool:
+ return c.isalpha() or c == "_" or c == "$"
+
+ def _is_id_part(c: str) -> bool:
+ return c.isalnum() or c == "_" or c == "$"
+
+ def _skip_ws_and_comments(s: str, i: int, lim: int) -> int:
+ while i < lim:
+ c = s[i]
+ if c.isspace():
+ i += 1
+ continue
+ if c == "/" and i + 1 < lim:
+ n2 = s[i + 1]
+ if n2 == "/": # line comment
+ i += 2
+ while i < lim and s[i] != "\n":
+ i += 1
+ continue
+ if n2 == "*": # block comment
+ j = s.find("*/", i + 2, lim)
+ i = lim if j == -1 else (j + 2)
+ continue
+ break
+ return i
- def _enum_constants_terminator(type_region: _TypeRegion) -> int:
+ def _signature_has_initializer_assignment(sig_start: int, name_start: int) -> bool:
"""
- Find the position of the semicolon that terminates the enum-constants list
- inside an enum body, or return -1 if no explicit terminator exists.
-
- Parameters
- ----------
- type_region : _TypeRegion
- The region describing the enum body. `type_region.start` is the index
- of the opening '{' and `type_region.end` is the index of the matching '}'.
-
- Returns
- -------
- int
- The index of the terminating ';' that appears at *brace depth 1* inside
- the enum body. If the constants section runs directly into the closing
- brace (i.e., no explicit ';'), returns -1. Also returns -1 for
- unterminated/degenerate cases.
+ True if there's an '=' before `name_start` that is NOT inside an annotation's (...) args
+ and not inside comments/strings.
- Caching
- -------
- Results are cached in `enum_term_cache` keyed by `type_region.start`.
+ This is used to reject parsing '(' that belongs to field initializers like:
+ Foo X = new Foo(...) { ... }
"""
- # Cache fast-path
- cached_pos = enum_term_cache.get(type_region.start)
- if cached_pos is not None:
- return cached_pos
-
- scan_pos: int = type_region.start + 1 # first character after '{'
- body_end: int = type_region.end # index of matching '}'
- brace_depth: int = 1 # already inside the enum body
-
- while scan_pos < body_end:
- next_pos = _skip_ws_comments(src, scan_pos)
- if next_pos != scan_pos:
- scan_pos = next_pos
- continue
-
- # Skip over literals so braces/semicolons inside them are ignored.
- if src.startswith('"""', scan_pos) or src[scan_pos] in ('"', "'"):
- scan_pos = _skip_java_string_like(src, scan_pos)
- continue
-
- ch = src[scan_pos]
-
- if ch == '{':
- brace_depth += 1
- elif ch == '}':
- brace_depth -= 1
- if brace_depth == 0:
- # Reached end of enum body without seeing a terminator.
- enum_term_cache[type_region.start] = -1
- return -1
- elif ch == ';' and brace_depth == 1:
- # The terminator separating constants from members.
- enum_term_cache[type_region.start] = scan_pos
- return scan_pos
- elif ch == '(':
- # Skip (...) blocks to avoid miscounting nested tokens.
- scan_pos = _match_balanced_parens(src, scan_pos) + 1
- continue
- elif ch == '<':
- # Skip generic type argument lists inside declarations.
- scan_pos = _match_balanced_angles(src, scan_pos) + 1
- continue
-
- scan_pos += 1
-
- # No explicit terminator found.
- enum_term_cache[type_region.start] = -1
- return -1
-
- def _in_enum_constants_section(name_pos: int) -> bool:
- reg = _enclosing_region(name_pos)
- if not reg or _region_kind(reg) != 'enum':
+ # Fast path: no '=' at all
+ if s.find("=", sig_start, name_start) == -1:
return False
- term = _enum_constants_terminator(reg)
- return term == -1 or name_pos < term
-
- def _scan_paren_internals_for_anonymous_classes(open_idx: int, close_idx: int, out_accum: List[str]):
- # Look for "new Type(...){ ... }" inside argument lists; recurse into their bodies
- k2 = open_idx + 1
- while k2 < close_idx:
- kk = _skip_ws_comments(src, k2)
- if kk != k2:
- k2 = kk; continue
- if k2 < close_idx and (src.startswith('"""', k2) or src[k2] in ('"', "'")):
- k2 = _skip_java_string_like(src, k2); continue
-
- pos = src.find('new', k2, close_idx)
- if pos == -1:
- break
- k2 = pos
-
- # Skip false positives inside ignored spans
- sp = _ignored_span_at(k2) if _DELIM_SRC_ID == id(src) else None
- if sp and sp[0] <= k2 < sp[1]:
- k2 = sp[1]; continue
-
- pre = src[k2 - 1] if k2 > open_idx + 1 else ''
- post = src[k2 + 3] if k2 + 3 < close_idx else ''
- if (k2 == open_idx + 1 or not _is_ident_part(pre)) and (k2 + 3 >= close_idx or not _is_ident_part(post)):
- t = _skip_ws_comments(src, k2 + 3)
- # Qualified type with optional type params
- while t < close_idx:
- if _is_ident_start(src[t]):
- t += 1
- while t < close_idx and _is_ident_part(src[t]): t += 1
- if t < close_idx and src[t] == '.':
- t += 1; continue
- elif t < close_idx and src[t] == '<':
- t = _match_balanced_angles(src, t) + 1
- else:
- break
- t = _skip_ws_comments(src, t)
-
- if t < close_idx and src[t] == '(':
- ctor_close = _match_balanced_parens(src, t)
- u = _skip_ws_comments(src, ctor_close + 1)
- if u < close_idx and src[u] == '{':
- body_end = _match_balanced_braces(src, u)
- if u + 1 <= body_end:
- out_accum.extend(_extract_methods_anywhere(src, start=u + 1, end=body_end, types=types))
- k2 = body_end + 1; continue
- k2 = ctor_close + 1; continue
-
- k2 = k2 + 3 # move past this "new"
-
- def _body_might_have_local_types(beg: int, end_: int) -> bool:
- s = src
- return (s.find('class', beg, end_) != -1 or
- s.find('interface', beg, end_) != -1 or
- s.find('enum', beg, end_) != -1 or
- s.find('record', beg, end_) != -1)
- out: List[str] = []
- i = start
- n = end
- seen_spans: Set[Tuple[int, int]] = set()
+ NORMAL3, LINE3, BLOCK3, STRING3, CHAR3, TEXT3 = 0, 1, 2, 3, 4, 5
+ st = NORMAL3
+ i3 = sig_start
+ while i3 < name_start:
+ c3 = s[i3]
+
+ if st == NORMAL3:
+ if c3 == "/" and i3 + 1 < name_start:
+ n3 = s[i3 + 1]
+ if n3 == "/":
+ st = LINE3
+ i3 += 2
+ continue
+ if n3 == "*":
+ st = BLOCK3
+ i3 += 2
+ continue
- while i < n:
- j = _skip_ws_comments(src, i)
- if j != i:
- i = j; continue
- if i >= n: break
+ if c3 == '"':
+ if i3 + 2 < name_start and s[i3:i3 + 3] == '"""':
+ st = TEXT3
+ i3 += 3
+ continue
+ st = STRING3
+ i3 += 1
+ continue
- if src.startswith('"""', i) or src[i] in ('"', "'"):
- i = _skip_java_string_like(src, i); continue
+ if c3 == "'":
+ st = CHAR3
+ i3 += 1
+ continue
+
+ # Skip an annotation (and its (...) if present) so '=' inside it doesn't count.
+ if c3 == "@":
+ i3 += 1
+ # qualified annotation name
+ while i3 < name_start and (_is_id_part(s[i3]) or s[i3] == "."):
+ i3 += 1
+ i3 = _skip_ws_and_comments(s, i3, name_start)
+ if i3 < name_start and s[i3] == "(":
+ r3 = _match_balanced_paren(s, i3, name_start)
+ if r3 == -1:
+ # malformed; be conservative (don't reject)
+ return False
+ i3 = r3 + 1
+ continue
- if src[i] == '(':
- name_start, name = _back_ident(src, i)
- close_p = _match_balanced_parens(src, i)
+ if c3 == "=":
+ return True
- if not name:
- if close_p < n:
- _scan_paren_internals_for_anonymous_classes(i, close_p, out)
- i = close_p + 1
+ i3 += 1
continue
- # Record headers are NOT methods/constructors; skip them explicitly.
- if _prev_word(src, name_start) == 'record':
- if close_p < n:
- _scan_paren_internals_for_anonymous_classes(i, close_p, out)
- i = close_p + 1
+ if st == LINE3:
+ if c3 == "\n":
+ st = NORMAL3
+ i3 += 1
continue
- # Control-flow invocations like `if (...)`, `for (...)` etc. are not methods.
- if name in _NON_METHOD_TOKENS:
- if close_p < n:
- _scan_paren_internals_for_anonymous_classes(i, close_p, out)
- i = close_p + 1
+ if st == BLOCK3:
+ if c3 == "*" and i3 + 1 < name_start and s[i3 + 1] == "/":
+ st = NORMAL3
+ i3 += 2
+ continue
+ i3 += 1
continue
- left_bound = _find_left_boundary(src, name_start)
-
- # "new" between boundary and name means likely a constructor call or anonymous class literal
- # (i.e., not a declaration). If found, optionally recurse into the anonymous class body.
- def _has_new_token_between(s: str, left: int, right: int) -> bool:
- left = max(0, left); right = min(len(s), right)
- i2 = left
- while True:
- pos = s.find('new', i2, right)
- if pos == -1: return False
- if _DELIM_SRC_ID == id(s):
- sp = _ignored_span_at(pos)
- if sp and sp[0] <= pos < sp[1]:
- i2 = pos + 3; continue
- pre = s[pos - 1] if pos > left else ''
- post = s[pos + 3] if pos + 3 < right else ''
- if (pos == left or not _is_ident_part(pre)) and (pos + 3 >= right or not _is_ident_part(post)):
- return True
- i2 = pos + 3
-
- if _has_new_token_between(src, left_bound, name_start):
- if close_p < n:
- _scan_paren_internals_for_anonymous_classes(i, close_p, out)
- kk = _skip_ws_comments(src, close_p + 1)
- if kk < n and src[kk] == '{':
- # Anonymous inner class: collect methods from its body.
- body_end = _match_balanced_braces(src, kk)
- if kk + 1 <= body_end:
- out.extend(_extract_methods_anywhere(src, start=kk + 1, end=body_end, types=types))
- i = body_end + 1
- else:
- i = close_p + 1
+ if st == STRING3:
+ if c3 == "\\":
+ i3 += 2
+ continue
+ if c3 == '"':
+ st = NORMAL3
+ i3 += 1
continue
- # Enum constant bodies (class bodies inside the constants section) — recurse.
- k = _skip_ws_comments(src, close_p + 1)
- if k < n and src[k] == '{' and _in_enum_constants_section(name_start):
- if close_p < n:
- _scan_paren_internals_for_anonymous_classes(i, close_p, out)
- body_end = _match_balanced_braces(src, k)
- if k + 1 <= body_end:
- out.extend(_extract_methods_anywhere(src, start=k + 1, end=body_end, types=types))
- i = body_end + 1
+ if st == CHAR3:
+ if c3 == "\\":
+ i3 += 2
+ continue
+ if c3 == "'":
+ st = NORMAL3
+ i3 += 1
continue
- # We expect a method/constructor body here.
- k = _skip_throws_clause(src, k)
- if k >= n or src[k] != '{':
- if close_p < n:
- _scan_paren_internals_for_anonymous_classes(i, close_p, out)
- i = close_p + 1
+ # TEXT3
+ if i3 + 2 < name_start and s[i3:i3 + 3] == '"""':
+ st = NORMAL3
+ i3 += 3
continue
+ i3 += 1
- sig_guess = _find_left_boundary(src, name_start)
- sig_start = _skip_leading_annotations(src, _skip_ws_comments(src, sig_guess))
-
- # Constructors are now included: if name == enclosing type name, we still emit it.
- body_end = _match_balanced_braces(src, k)
- sig_start = _skip_ws_comments(src, sig_start) # final guard
+ return False
- header_no_ann = _strip_annotations_in_range(src, sig_start, k)
- clean_header = _strip_comments_keep_strings_range(header_no_ann, 0, len(header_no_ann))
-
- clean_body = _strip_body_comments_and_annotations(src, k, body_end)
+ def _match_balanced_paren(s: str, lparen: int, lim: int) -> int:
+ NORMAL, LINE, BLOCK, STRING, CHAR, TEXT = 0, 1, 2, 3, 4, 5
+ state = NORMAL
+ i = lparen + 1
+ depth = 1
- method_text = clean_header + clean_body
+ while i < lim:
+ c = s[i]
+ if state == NORMAL:
+ if c == "/" and i + 1 < lim:
+ n2 = s[i + 1]
+ if n2 == "/":
+ state = LINE
+ i += 2
+ continue
+ if n2 == "*":
+ state = BLOCK
+ i += 2
+ continue
+ elif c == '"':
+ if i + 2 < lim and s[i: i + 3] == '"""':
+ state = TEXT
+ i += 3
+ continue
+ state = STRING
+ i += 1
+ continue
+ elif c == "'":
+ state = CHAR
+ i += 1
+ continue
+ elif c == "(":
+ depth += 1
+ elif c == ")":
+ depth -= 1
+ if depth == 0:
+ return i
+ elif state == LINE:
+ if c == "\n":
+ state = NORMAL
+ elif state == BLOCK:
+ if c == "*" and i + 1 < lim and s[i + 1] == "/":
+ state = NORMAL
+ i += 2
+ continue
+ elif state == STRING:
+ if c == "\\":
+ i += 2
+ continue
+ if c == '"':
+ state = NORMAL
+ elif state == CHAR:
+ if c == "\\":
+ i += 2
+ continue
+ if c == "'":
+ state = NORMAL
+ else: # TEXT
+ if i + 2 < lim and s[i: i + 3] == '"""':
+ state = NORMAL
+ i += 3
+ continue
+ i += 1
- span = (sig_start, body_end)
- if span not in seen_spans:
- seen_spans.add(span)
- out.append(method_text)
+ return -1
- # Recurse into body for local/anonymous types
- if k + 1 <= body_end and _body_might_have_local_types(k + 1, body_end):
- out.extend(_extract_methods_anywhere(src, start=k + 1, end=body_end, types=types))
+ def _match_balanced_brace(s: str, lbrace: int, lim: int) -> int:
+ NORMAL, LINE, BLOCK, STRING, CHAR, TEXT = 0, 1, 2, 3, 4, 5
+ state = NORMAL
+ i = lbrace + 1
+ depth = 1
- i = body_end + 1
- continue
+ while i < lim:
+ c = s[i]
+ if state == NORMAL:
+ if c == "/" and i + 1 < lim:
+ n2 = s[i + 1]
+ if n2 == "/":
+ state = LINE
+ i += 2
+ continue
+ if n2 == "*":
+ state = BLOCK
+ i += 2
+ continue
+ elif c == '"':
+ if i + 2 < lim and s[i: i + 3] == '"""':
+ state = TEXT
+ i += 3
+ continue
+ state = STRING
+ i += 1
+ continue
+ elif c == "'":
+ state = CHAR
+ i += 1
+ continue
+ elif c == "{":
+ depth += 1
+ elif c == "}":
+ depth -= 1
+ if depth == 0:
+ return i
+ elif state == LINE:
+ if c == "\n":
+ state = NORMAL
+ elif state == BLOCK:
+ if c == "*" and i + 1 < lim and s[i + 1] == "/":
+ state = NORMAL
+ i += 2
+ continue
+ elif state == STRING:
+ if c == "\\":
+ i += 2
+ continue
+ if c == '"':
+ state = NORMAL
+ elif state == CHAR:
+ if c == "\\":
+ i += 2
+ continue
+ if c == "'":
+ state = NORMAL
+ else: # TEXT
+ if i + 2 < lim and s[i: i + 3] == '"""':
+ state = NORMAL
+ i += 3
+ continue
+ i += 1
- # Secondary pattern: brace-first line (signature ended previous line)
- if src[i] == '{':
- p = src.rfind(')', start, i)
- if p != -1:
- open_p = _match_paren_backwards(src, p)
- if open_p > 0 and _header_closes_at_brace(src, p, i):
- name_start, name = _back_ident(src, open_p)
- if name and name not in _NON_METHOD_TOKENS:
- left_bound = _find_left_boundary(src, name_start)
- # Filter out anonymous class/ctor callsites; constructors are INCLUDED.
- def _has_new_between(s, a, b):
- a = max(0, a); b = min(len(s), b)
- pos = s.find('new', a, b)
- while pos != -1:
- sp = _ignored_span_at(pos) if _DELIM_SRC_ID == id(s) else None
- if not sp or not (sp[0] <= pos < sp[1]):
- pre = s[pos - 1] if pos > a else ''
- post = s[pos + 3] if s and (pos + 3) < b else ''
- if (pos == a or not _is_ident_part(pre)) and (pos + 3 >= b or not _is_ident_part(post)):
- return True
- pos = s.find('new', pos + 3, b)
- return False
- if not _has_new_between(src, left_bound, name_start) and _prev_word(src, name_start) != 'record':
- body_end = _match_balanced_braces(src, i)
- sig_guess = _find_left_boundary(src, name_start)
- sig_start = _skip_leading_annotations(src, _skip_ws_comments(src, sig_guess))
- sig_start = _skip_ws_comments(src, sig_start)
+ return -1
+
+ def _line_is_annotation_only(s: str, line_start: int, line_end: int) -> bool:
+ i = line_start
+ while i < line_end and s[i].isspace():
+ i += 1
+ if i >= line_end or s[i] != "@":
+ return False
+
+ # Parse one-or-more annotations: @Name or @pkg.Name or @Name(...)
+ while True:
+ if i >= line_end or s[i] != "@":
+ break
+ i += 1
+ if i >= line_end or not _is_id_start(s[i]):
+ return False
+ i += 1
+ while i < line_end and (_is_id_part(s[i]) or s[i] == "."):
+ i += 1
+
+ i = _skip_ws_and_comments(s, i, line_end)
+ if i < line_end and s[i] == "(":
+ r = _match_balanced_paren(s, i, line_end)
+ if r == -1:
+ return False
+ i = r + 1
+
+ i = _skip_ws_and_comments(s, i, line_end)
+ if i < line_end and s[i] == "@":
+ continue
+ break
+
+ i = _skip_ws_and_comments(s, i, line_end)
+ while i < line_end and s[i].isspace():
+ i += 1
+ return i >= line_end
+
+ def _skip_annotation_only_lines(s: str, i: int, lim: int) -> int:
+ n_local = len(s)
+ while i < lim:
+ line_end = s.find("\n", i, lim)
+ if line_end == -1:
+ line_end = min(lim, n_local)
+ if not _line_is_annotation_only(s, i, line_end):
+ return i
+ i = line_end + 1
+ return i
+
+ # -------------------------
+ # Type discovery (named inner chain computed as qual_name)
+ # -------------------------
+ def _find_all_named_types(s: str) -> list[_TypeRegion]:
+ n_local = len(s)
+ NORMAL, LINE, BLOCK, STRING, CHAR, TEXT = 0, 1, 2, 3, 4, 5
+ state = NORMAL
+
+ class_kinds = {"class", "interface", "enum", "record"}
+
+ out_types: list[_TypeRegion] = []
+ type_stack: list[int] = [] # indices of open named types
+ brace_stack: list[tuple[int, int]] = [] # (open_brace_pos, named_type_idx_or_-1)
+
+ pending_kind: str | None = None
+ pending_start: int = -1
+ pending_name: str | None = None
+
+ def _prev_non_ws_local(pos: int) -> int:
+ while pos >= 0 and s[pos].isspace():
+ pos -= 1
+ return pos
+
+ i = 0
+ while i < n_local:
+ c = s[i]
+
+ if state == NORMAL:
+ if c == "/" and i + 1 < n_local:
+ n2 = s[i + 1]
+ if n2 == "/":
+ state = LINE
+ i += 2
+ continue
+ if n2 == "*":
+ state = BLOCK
+ i += 2
+ continue
+ if c == '"':
+ if i + 2 < n_local and s[i: i + 3] == '"""':
+ state = TEXT
+ i += 3
+ continue
+ state = STRING
+ i += 1
+ continue
+ if c == "'":
+ state = CHAR
+ i += 1
+ continue
+
+ if _is_id_start(c):
+ j = i + 1
+ while j < n_local and _is_id_part(s[j]):
+ j += 1
+ tok = s[i:j]
+
+ if pending_kind is None:
+ if tok in class_kinds:
+ # ------------------------------
+ # ✅ CRITICAL GUARD:
+ # Ignore ".class" and other member-access uses:
+ # StandardScheme.class.equals(...)
+ # Foo.class
+ # This prevents bogus "class equals" types.
+ # ------------------------------
+ p = _prev_non_ws_local(i - 1)
+ if p >= 0 and s[p] == ".":
+ i = j
+ continue
+ # Defensive: ignore "class." patterns too
+ if j < n_local and s[j] == ".":
+ i = j
+ continue
+
+ pending_kind = tok
+ pending_start = i
+ pending_name = None
+ else:
+ if pending_name is None:
+ pending_name = tok
+
+ i = j
+ continue
+
+ if c == "{":
+ brace_pos = i
+ if pending_kind is not None and pending_name is not None:
+ parent = type_stack[-1] if type_stack else -1
+ qual = pending_name if parent == -1 else (out_types[parent].qual_name + "." + pending_name)
+ out_types.append(
+ _TypeRegion(
+ kind=pending_kind,
+ name=pending_name,
+ start=pending_start,
+ end=-1,
+ body_start=brace_pos + 1,
+ body_end=-1,
+ parent=parent,
+ qual_name=qual,
+ )
+ )
+ t_idx = len(out_types) - 1
+ brace_stack.append((brace_pos, t_idx))
+ type_stack.append(t_idx)
+ pending_kind = None
+ pending_start = -1
+ pending_name = None
+ else:
+ brace_stack.append((brace_pos, -1))
+ i += 1
+ continue
+
+ if c == "}":
+ if brace_stack:
+ _, t_idx = brace_stack.pop()
+ if t_idx != -1:
+ out_types[t_idx].body_end = i
+ out_types[t_idx].end = i + 1
+ if type_stack and type_stack[-1] == t_idx:
+ type_stack.pop()
+ else:
+ # defensive fallback
+ try:
+ type_stack.remove(t_idx)
+ except ValueError:
+ pass
+ i += 1
+ continue
+
+ i += 1
+ continue
+
+ if state == LINE:
+ if c == "\n":
+ state = NORMAL
+ i += 1
+ continue
+
+ if state == BLOCK:
+ if c == "*" and i + 1 < n_local and s[i + 1] == "/":
+ state = NORMAL
+ i += 2
+ continue
+ i += 1
+ continue
+
+ if state == STRING:
+ if c == "\\":
+ i += 2
+ continue
+ if c == '"':
+ state = NORMAL
+ i += 1
+ continue
+
+ if state == CHAR:
+ if c == "\\":
+ i += 2
+ continue
+ if c == "'":
+ state = NORMAL
+ i += 1
+ continue
+
+ # TEXT
+ if i + 2 < n_local and s[i: i + 3] == '"""':
+ state = NORMAL
+ i += 3
+ continue
+ i += 1
+
+ # Keep only closed types
+ closed: list[_TypeRegion] = []
+ for t in out_types:
+ if t.end != -1 and t.body_end != -1:
+ closed.append(t)
+ return closed
+
+ def _normalize_types(maybe_types: list) -> list[_TypeRegion]:
+ norm: list[_TypeRegion] = []
+ for t in maybe_types:
+ # attribute-style
+ if hasattr(t, "kind") and hasattr(t, "name") and hasattr(t, "body_start") and hasattr(t, "body_end"):
+ kind = getattr(t, "kind")
+ name = getattr(t, "name")
+ tstart = getattr(t, "start", -1)
+ tend = getattr(t, "end", -1)
+ bstart = getattr(t, "body_start")
+ bend = getattr(t, "body_end")
+ parent = getattr(t, "parent", -1)
+ qual = getattr(t, "qual_name", "") or name
+ norm.append(_TypeRegion(kind, name, tstart, tend, bstart, bend, parent, qual))
+ continue
+
+ # dict-style
+ if isinstance(t, dict):
+ kind = t.get("kind", "")
+ name = t.get("name", "")
+ tstart = t.get("start", -1)
+ tend = t.get("end", -1)
+ bstart = t.get("body_start", -1)
+ bend = t.get("body_end", -1)
+ parent = t.get("parent", -1)
+ qual = t.get("qual_name") or name
+ norm.append(_TypeRegion(kind, name, tstart, tend, bstart, bend, parent, qual))
+ continue
+
+ # tuple/list positional fallback (kind,name,start,end,body_start,body_end,parent,qual_name)
+ if isinstance(t, (tuple, list)) and len(t) >= 6:
+ kind = t[0]
+ name = t[1]
+ tstart = t[2] if len(t) > 2 else -1
+ tend = t[3] if len(t) > 3 else -1
+ bstart = t[4]
+ bend = t[5]
+ parent = t[6] if len(t) > 6 else -1
+ qual = t[7] if len(t) > 7 else name
+ norm.append(_TypeRegion(kind, name, tstart, tend, bstart, bend, parent, qual))
+ continue
+
+ # unknown shape -> ignore
+ return norm
+
+ def _repair_qual_names(types_list: list[_TypeRegion]) -> None:
+ """
+ If caller-provided `types` lack correct qual_name chains, repair them using `parent`.
+ Safe no-op if already correct.
+ """
+ any_needs = False
+ for t in types_list:
+ if t.parent != -1 and ("." not in (t.qual_name or "")):
+ any_needs = True
+ break
+ if not any_needs:
+ return
+
+ memo: dict[int, str] = {}
+
+ def qual(i: int) -> str:
+ if i in memo:
+ return memo[i]
+ t = types_list[i]
+ if t.parent == -1:
+ memo[i] = t.name
+ else:
+ p = qual(t.parent)
+ memo[i] = p + "." + t.name if p else t.name
+ return memo[i]
+
+ for idx in range(len(types_list)):
+ types_list[idx].qual_name = qual(idx)
+
+ # -------------------------
+ # Main extraction logic
+ # -------------------------
+ s = src
+ n = len(s)
+ if end is None:
+ end = n
+ if start < 0:
+ start = 0
+ if end > n:
+ end = n
+ if start >= end:
+ return []
+
+ if types is None:
+ types2 = _find_all_named_types(s)
+ else:
+ types2 = _normalize_types(types)
+ _repair_qual_names(types2)
+
+ # Map opening brace positions to named type index.
+ # Named type opening brace is at (body_start - 1).
+ lbrace_to_type: dict[int, int] = {}
+ for idx, t in enumerate(types2):
+ lbrace_to_type[t.body_start - 1] = idx
+
+ # Precompute "inner-type" comment payload per named type idx.
+ # We strip the top-level declaring type (first segment) and keep the rest.
+ inner_comment_by_idx: list[str | None] = [None] * len(types2)
+ for idx, t in enumerate(types2):
+ if t.parent != -1:
+ q = t.qual_name or t.name
+ dot = q.find(".")
+ inner_comment_by_idx[idx] = q[dot + 1:] if dot != -1 else t.name
+
+ # Scanner state
+ NORMAL, LINE, BLOCK, STRING, CHAR, TEXT = 0, 1, 2, 3, 4, 5
+ state = NORMAL
+
+ out: list[str] = []
+
+ # Brace stack kinds: 0 = other, 1 = named type body, 2 = anonymous class body
+ brace_stack: list[tuple[int, int]] = []
+ brace_depth = 0
+
+ # Track parentheses and whether they belong to a `new ... ( ... )` object creation
+ paren_objcreate_stack: list[bool] = []
+ in_new_expr = False # True from seeing token "new" until we consume its '(' or cancel
+
+ expecting_anon_class_brace = False
+ class_bodies: list[_ClassBodyFrame] = []
+
+ modifiers = {
+ "public", "protected", "private", "abstract", "final", "static", "native",
+ "synchronized", "strictfp", "default", "transient", "volatile", "sealed", "non-sealed",
+ }
+ ctrl_like = {
+ "if", "for", "while", "switch", "catch", "do", "try", "return", "throw", "new",
+ "case", "assert", "break", "continue", "else", "finally", "yield",
+ "this", "super",
+ }
+
+ def _current_named_type_name() -> str | None:
+ if not class_bodies:
+ return None
+ idx = class_bodies[-1].named_type_idx
+ if idx < 0:
+ return None
+ return types2[idx].name
+
+ def _current_inner_comment() -> str | None:
+ if not class_bodies:
+ return None
+ idx = class_bodies[-1].named_type_idx
+ if idx < 0:
+ return None
+ return inner_comment_by_idx[idx]
+
+ def _is_constructor_allowed(method_name: str) -> bool:
+ tn = _current_named_type_name()
+ return tn is not None and method_name == tn
+
+ def _prev_non_ws(pos: int, floor: int) -> int:
+ j = pos
+ while j >= floor and s[j].isspace():
+ j -= 1
+ return j
+
+ def _read_prev_identifier(pos: int, floor: int) -> tuple[str, int, int]:
+ j = _prev_non_ws(pos, floor)
+ if j < floor:
+ return "", -1, -1
+ if not _is_id_part(s[j]):
+ return "", -1, -1
+ k = j
+ while k >= floor and _is_id_part(s[k]):
+ k -= 1
+ return s[k + 1: j + 1], k + 1, j + 1
+
+ def _has_return_type_before_name(name_start: int, floor: int) -> bool:
+ j = name_start - 1
+ while j >= floor:
+ while j >= floor and s[j].isspace():
+ j -= 1
+ if j < floor:
+ break
+ if not _is_id_part(s[j]):
+ j -= 1
+ continue
+
+ tok, ts, _ = _read_prev_identifier(j, floor)
+ if not tok:
+ break
+
+ # skip annotation identifiers (@Something)
+ p = _prev_non_ws(ts - 1, floor)
+ if p >= floor and s[p] == "@":
+ j = p - 1
+ continue
+
+ if tok in modifiers:
+ j = ts - 1
+ continue
+
+ # ✅ IMPORTANT: keywords like "new"/"return"/"if"/... are NOT return types
+ if tok in ctrl_like:
+ j = ts - 1
+ continue
+
+ return True
+ return False
+
+
+ def _try_parse_method_at_lparen(lparen: int) -> tuple[int, int] | None:
+ frame = class_bodies[-1]
+ member_floor = frame.member_start
+
+ # method/ctor name token before '('
+ j = _prev_non_ws(lparen - 1, member_floor)
+ if j < member_floor or not _is_id_part(s[j]):
+ return None
+
+ k = j
+ while k >= member_floor and _is_id_part(s[k]):
+ k -= 1
+ name_start = k + 1
+ name = s[name_start: j + 1]
+ if not name or name in ctrl_like:
+ return None
+
+ # ------------------------------------------------------------
+ # ✅ NEW: Reject annotation argument parens: "@Something(...)".
+ # This prevents "@SuppressWarnings(...)" from being mis-parsed as a method.
+ # ------------------------------------------------------------
+ p0 = _prev_non_ws(name_start - 1, member_floor)
+ if p0 >= member_floor and s[p0] == "@":
+ return None
+
+ # Reject qualified/call expressions: ".name(" or "::name("
+ if p0 >= member_floor and s[p0] in (".", ":"):
+ return None
+
+ # Compute the "real signature start" (skip ws/comments + annotation-only lines)
+ sig0 = _skip_ws_and_comments(s, member_floor, end)
+ sig0 = _skip_annotation_only_lines(s, sig0, min(lparen, end))
+
+ # ------------------------------------------------------------
+ # ✅ Field initializer guard:
+ # If there's an '=' before this '(' in the same member signature,
+ # we're in an initializer, not a method/ctor declaration.
+ # ------------------------------------------------------------
+ if sig0 < lparen and s.find("=", sig0, lparen) != -1:
+ if _signature_has_initializer_assignment(sig0, lparen):
+ return None
+
+ # Reject type declarations that have (...) headers (e.g., "record R(...){...}")
+ prev_tok, _, _ = _read_prev_identifier(name_start - 1, member_floor)
+ if prev_tok in ("class", "interface", "enum", "record"):
+ return None
+
+ # Reject "new Name(" (constructor call) vs declaration
+ if prev_tok == "new":
+ return None
+
+ rparen = _match_balanced_paren(s, lparen, end)
+ if rparen == -1:
+ return None
+
+ # Find '{' starting the body (skip ws/comments; allow throws etc)
+ i2 = _skip_ws_and_comments(s, rparen + 1, end)
+ if i2 >= end:
+ return None
+ if i2 + 1 < end and s[i2: i2 + 2] == "->":
+ return None
+ if s[i2] == ";":
+ return None
+
+ # Scan forward until '{' or ';' or '->', respecting strings/comments.
+ NORMAL2, LINE2, BLOCK2, STRING2, CHAR2, TEXT2 = 0, 1, 2, 3, 4, 5
+ st2 = NORMAL2
+ j2 = i2
+ body_lbrace = -1
+
+ while j2 < end:
+ c2 = s[j2]
+ if st2 == NORMAL2:
+ if c2 == "/" and j2 + 1 < end:
+ n2 = s[j2 + 1]
+ if n2 == "/":
+ st2 = LINE2
+ j2 += 2
+ continue
+ if n2 == "*":
+ st2 = BLOCK2
+ j2 += 2
+ continue
+ if c2 == '"':
+ if j2 + 2 < end and s[j2: j2 + 3] == '"""':
+ st2 = TEXT2
+ j2 += 3
+ continue
+ st2 = STRING2
+ j2 += 1
+ continue
+ if c2 == "'":
+ st2 = CHAR2
+ j2 += 1
+ continue
+
+ if j2 + 1 < end and s[j2: j2 + 2] == "->":
+ return None
+ if c2 == ";":
+ return None
+ if c2 == "{":
+ body_lbrace = j2
+ break
+
+ elif st2 == LINE2:
+ if c2 == "\n":
+ st2 = NORMAL2
+ elif st2 == BLOCK2:
+ if c2 == "*" and j2 + 1 < end and s[j2 + 1] == "/":
+ st2 = NORMAL2
+ j2 += 2
+ continue
+ elif st2 == STRING2:
+ if c2 == "\\":
+ j2 += 2
+ continue
+ if c2 == '"':
+ st2 = NORMAL2
+ elif st2 == CHAR2:
+ if c2 == "\\":
+ j2 += 2
+ continue
+ if c2 == "'":
+ st2 = NORMAL2
+ else: # TEXT2
+ if j2 + 2 < end and s[j2: j2 + 3] == '"""':
+ st2 = NORMAL2
+ j2 += 3
+ continue
+
+ j2 += 1
+
+ if body_lbrace == -1:
+ return None
- header_no_ann = _strip_annotations_in_range(src, sig_start, i)
- clean_header = _strip_comments_keep_strings_range(header_no_ann, 0, len(header_no_ann))
+ # ------------------------------------------------------------
+ # ✅ IMPORTANT CHANGE: use sig0 as the "floor" for return-type detection.
+ # This prevents Javadoc/comments/previous-members from faking a return type.
+ # ------------------------------------------------------------
+ if frame.named_type_idx >= 0:
+ if not _has_return_type_before_name(name_start, sig0):
+ if not _is_constructor_allowed(name):
+ return None
+ else:
+ if not _has_return_type_before_name(name_start, sig0):
+ return None
+
+ rbrace = _match_balanced_brace(s, body_lbrace, end)
+ if rbrace == -1:
+ return None
+ method_end = rbrace + 1
+
+ # Signature start excluding annotation-only lines
+ sig_start = _skip_annotation_only_lines(s, sig0, min(body_lbrace, end))
+ return sig_start, method_end
+
+ i = start
+ while i < end:
+ c = s[i]
+
+ if state == NORMAL:
+ # comment/string/text transitions
+ if c == "/" and i + 1 < end:
+ n2 = s[i + 1]
+ if n2 == "/":
+ state = LINE
+ i += 2
+ continue
+ if n2 == "*":
+ state = BLOCK
+ i += 2
+ continue
- clean_body = _strip_body_comments_and_annotations(src, i, body_end)
+ if c == '"':
+ if i + 2 < end and s[i: i + 3] == '"""':
+ state = TEXT
+ i += 3
+ continue
+ state = STRING
+ i += 1
+ continue
- method_text = clean_header + clean_body
+ if c == "'":
+ state = CHAR
+ i += 1
+ continue
- span = (sig_start, body_end)
- if span not in seen_spans:
- seen_spans.add(span)
- out.append(method_text)
+ # identifiers: detect 'new'
+ if _is_id_start(c):
+ j = i + 1
+ while j < end and _is_id_part(s[j]):
+ j += 1
+ tok = s[i:j]
+ if tok == "new":
+ in_new_expr = True
+ i = j
+ continue
- if i + 1 <= body_end and _body_might_have_local_types(i + 1, body_end):
- out.extend(_extract_methods_anywhere(src, start=i + 1, end=body_end, types=types))
- i = body_end + 1
+ # cancel `new` expectation on array creation or obvious expression terminators
+ if in_new_expr and c in ("[", "{", ";"):
+ in_new_expr = False
+
+ # Cancel expecting anon brace if we hit meaningful char that's not '{'
+ if expecting_anon_class_brace and not c.isspace() and c != "{":
+ expecting_anon_class_brace = False
+
+ if c == "(":
+ paren_objcreate_stack.append(in_new_expr)
+ in_new_expr = False # consumed the object-creation marker for this '('
+
+ # method parsing at class member level
+ if class_bodies:
+ frame = class_bodies[-1]
+ # Only extract methods for *named* types; skip anonymous-class bodies (new X() { ... }).
+ if frame.named_type_idx >= 0 and brace_depth == frame.body_depth and frame.enum_constants_done:
+ parsed = _try_parse_method_at_lparen(i)
+ if parsed is not None:
+ sig_start, method_end = parsed
+ method_src = s[sig_start:method_end]
+
+ # Strip comments/javadocs/annotations/package/imports FIRST
+ method_src = _strip_java_noise(method_src)
+
+ # Then append the inner-type marker so it survives stripping
+ inner = _current_inner_comment()
+ if inner:
+ method_src = method_src.rstrip() + f"\n/* inner-type: {inner} */"
+
+ out.append(method_src)
+ frame.member_start = method_end
+ i = method_end
continue
+
+ i += 1
+ continue
+
+ if c == ")":
+ if paren_objcreate_stack and paren_objcreate_stack.pop():
+ expecting_anon_class_brace = True
+ i += 1
+ continue
+
+ if c == "{":
+ brace_depth += 1
+
+ t_idx = lbrace_to_type.get(i, -1)
+ if t_idx != -1:
+ brace_stack.append((i, 1))
+ is_enum = (types2[t_idx].kind == "enum")
+ class_bodies.append(_ClassBodyFrame(t_idx, brace_depth, i + 1, is_enum))
+ expecting_anon_class_brace = False
+ in_new_expr = False
+ i += 1
+ continue
+
+ if expecting_anon_class_brace:
+ brace_stack.append((i, 2))
+ class_bodies.append(_ClassBodyFrame(-1, brace_depth, i + 1, False))
+ expecting_anon_class_brace = False
+ in_new_expr = False
+ i += 1
+ continue
+
+ brace_stack.append((i, 0))
+ in_new_expr = False
+ i += 1
+ continue
+
+ if c == "}":
+ if brace_stack:
+ _, kind = brace_stack.pop()
+
+ # if closing a block directly inside a class body, advance member_start
+ if class_bodies:
+ frame = class_bodies[-1]
+ if kind == 0 and brace_depth == frame.body_depth + 1:
+ frame.member_start = i + 1
+
+ if kind == 1:
+ # closing named type: pop frame, and advance parent's member_start
+ if class_bodies and class_bodies[-1].named_type_idx != -1:
+ class_bodies.pop()
+ if class_bodies:
+ class_bodies[-1].member_start = i + 1
+ elif kind == 2:
+ # closing anonymous class: pop frame only
+ if class_bodies and class_bodies[-1].named_type_idx == -1:
+ class_bodies.pop()
+
+ brace_depth = max(0, brace_depth - 1)
+ expecting_anon_class_brace = False
+ in_new_expr = False
+ i += 1
+ continue
+
+ if c == ";" and class_bodies:
+ frame = class_bodies[-1]
+ if brace_depth == frame.body_depth:
+ # Enum constant section ends at the first semicolon at enum-body depth.
+ if frame.is_enum and not frame.enum_constants_done:
+ frame.enum_constants_done = True
+ frame.member_start = i + 1
+ i += 1
+ continue
+
+ i += 1
+ continue
+
+ if state == LINE:
+ if c == "\n":
+ state = NORMAL
+ i += 1
+ continue
+
+ if state == BLOCK:
+ if c == "*" and i + 1 < end and s[i + 1] == "/":
+ state = NORMAL
+ i += 2
+ continue
+ i += 1
+ continue
+
+ if state == STRING:
+ if c == "\\":
+ i += 2
+ continue
+ if c == '"':
+ state = NORMAL
+ i += 1
+ continue
+
+ if state == CHAR:
+ if c == "\\":
+ i += 2
+ continue
+ if c == "'":
+ state = NORMAL
+ i += 1
+ continue
+
+ # TEXT
+ if i + 2 < end and s[i: i + 3] == '"""':
+ state = NORMAL
+ i += 3
+ continue
i += 1
return out
@@ -1647,383 +2643,610 @@ def _has_new_between(s, a, b):
def extract_methods(java_source: str) -> List[str]:
"""
- Returns a list of Java method bodies/signatures (with bodies) plus all arrow-lambda
- forms found in the file. Excludes constructors, comments/Javadoc, and strips
- any method-level annotations that start before the signature token.
+ Returns a list of Java method/constructor bodies/signatures (with bodies) plus all arrow-lambda
+ forms found in the file. Original comments/Javadoc are excluded and method-level annotations are
+ stripped during extraction. For methods/constructors declared inside a *nested named type*, a small synthetic
+ trailing comment is appended:
+        /* inner-type: <Nested.Type.Path> */
+ (Top-level type members are unchanged.)
Order: methods first, then lambdas. Optimized for large files by pre-indexing
comments/strings and delimiter pairs (O(1) matching for (), {}).
+
+ Additional filtering:
+ - Exclude only the *standard* Object overrides:
+ * boolean equals(Object)
+ * int hashCode()
+ * String toString()
+ (Overloads are kept.)
"""
+ def _is_ident_char(ch: str) -> bool:
+ return ch.isalnum() or ch in ('_', '$')
+
+ def _read_ident_fwd(s: str, pos: int, end: int) -> Tuple[int, str]:
+ if pos >= end or not _is_ident_char(s[pos]):
+ return pos, ""
+ start = pos
+ pos += 1
+ while pos < end and _is_ident_char(s[pos]):
+ pos += 1
+ return pos, s[start:pos]
+
+ def _read_qualified_ident_fwd(s: str, pos: int, end: int) -> Tuple[int, str]:
+ pos, seg = _read_ident_fwd(s, pos, end)
+ if not seg:
+ return pos, ""
+ parts = [seg]
+ while pos < end and s[pos] == '.':
+ dot = pos
+ pos += 1
+ pos2, seg2 = _read_ident_fwd(s, pos, end)
+ if not seg2:
+ # not actually qualified; stop at dot
+ return dot, ".".join(parts)
+ parts.append(seg2)
+ pos = pos2
+ return pos, ".".join(parts)
+
+ def _skip_ws(s: str, pos: int, end: int) -> int:
+ while pos < end and s[pos].isspace():
+ pos += 1
+ return pos
+
+ def _has_top_level_comma(params: str) -> bool:
+        # track generic nesting depth so commas inside e.g. Map<K, V> don't look like param separators
+ depth = 0
+ for ch in params:
+ if ch == '<':
+ depth += 1
+ elif ch == '>' and depth > 0:
+ depth -= 1
+ elif ch == ',' and depth == 0:
+ return True
+ return False
+
+ def _standard_override_kind(method_text: str) -> str:
+ """
+ Return "equals" / "hashCode" / "toString" if this is exactly the standard Object override,
+ else "".
+
+ Assumptions: method_text is already stripped of comments/Javadoc/annotations by extraction.
+ """
+ open_paren = method_text.find('(')
+ if open_paren == -1:
+ return ""
+
+ # Find the first '{' after '('; that's the method/ctor body start in extracted text.
+ brace = method_text.find('{', open_paren)
+ if brace == -1:
+ return ""
+
+ # Close paren must be before the body.
+ close_paren = method_text.rfind(')', open_paren, brace)
+ if close_paren == -1:
+ return ""
+
+ # Method name: scan left from '('
+ j = open_paren - 1
+ while j >= 0 and method_text[j].isspace():
+ j -= 1
+ name_end = j + 1
+ while j >= 0 and _is_ident_char(method_text[j]):
+ j -= 1
+ name_start = j + 1
+ method_name = method_text[name_start:name_end]
+ if method_name not in ("equals", "hashCode", "toString"):
+ return ""
+
+ # Return type token: scan left from method_name_start
+ k = name_start - 1
+ while k >= 0 and method_text[k].isspace():
+ k -= 1
+ rt_end = k + 1
+ while k >= 0 and (method_text[k].isalnum() or method_text[k] in ('_', '$', '.')):
+ k -= 1
+ return_type = method_text[k + 1:rt_end]
+
+ params = method_text[open_paren + 1:close_paren]
+
+ # ---- hashCode() ----
+ if method_name == "hashCode":
+ if return_type != "int":
+ return ""
+ if params.strip():
+ return ""
+ return "hashCode"
+
+ # ---- toString() ----
+ if method_name == "toString":
+ if return_type not in ("String", "java.lang.String"):
+ return ""
+ if params.strip():
+ return ""
+ return "toString"
+
+ # ---- equals(Object) ----
+ if method_name == "equals":
+ if return_type != "boolean":
+ return ""
+ p = params.strip()
+ if not p:
+ return ""
+ if _has_top_level_comma(p):
+ return ""
+
+ # parse: [final] (Object|java.lang.Object)
+ tmp = p
+ pos = 0
+ end = len(tmp)
+
+ pos = _skip_ws(tmp, pos, end)
+ pos2, maybe_final = _read_ident_fwd(tmp, pos, end)
+ if maybe_final == "final":
+ pos = _skip_ws(tmp, pos2, end)
+
+ pos, type_name = _read_qualified_ident_fwd(tmp, pos, end)
+ if type_name not in ("Object", "java.lang.Object"):
+ return ""
+
+ pos = _skip_ws(tmp, pos, end)
+
+ # Reject arrays/varargs/generic suffixes on the parameter type (not the standard override)
+ if pos < end and (tmp[pos] == '[' or tmp[pos] == '<' or tmp.startswith("...", pos)):
+ return ""
+
+ pos, param_name = _read_ident_fwd(tmp, pos, end)
+ if not param_name:
+ return ""
+
+ # nothing else meaningful should follow in a normal single-param decl
+ if tmp[_skip_ws(tmp, pos, end):].strip():
+ return ""
+
+ return "equals"
+
+ return ""
+
_populate_delim_index_cache(java_source)
try:
- # Extract
methods = _extract_methods_anywhere(java_source, 0, len(java_source))
- # Lambdas in a separate pass across the whole file (so we don't double-count)
+
+ # Filter only the standard Object overrides (keep overloads).
+ filtered: List[str] = [m for m in methods if not _standard_override_kind(m)]
+
lambdas = _extract_lambdas_in_range(java_source, 0, len(java_source))
- return methods + lambdas
+ return filtered + lambdas
finally:
_clear_delim_index_cache()
-def extract_inner_classes(src: str) -> List[str]:
- """
- Return a list of source slices for *inner* named types (class/interface/enum/record),
- including nested inner types. Each returned slice is prefixed with the file's
- top-level import statements (including 'import static ...;').
+def extract_inner_classes(src: str) -> list[str]:
"""
+ Extract Java *named* inner classes/interfaces/enums/records (with full bodies)
+ from a compilation unit text.
- # -----------------------
- # Tiny local "utilities"
- # -----------------------
- def _is_ident_char(ch):
- return ch.isalnum() or ch in ('_', '$')
+ Changes vs prior behavior:
+ - Does NOT prefix extracted inner types with the file's package/import block.
+ - Strips from each extracted inner type snippet:
+ * // line comments
+ * /* ... */ block comments (incl. Javadocs)
+ * annotations (e.g. @SuppressWarnings(...), @Nullable, @A.B)
+ while preserving strings/chars/text blocks.
+
+ Notes:
+ - Anonymous inner classes are NOT returned here (TODO remains).
+ - Performance: single forward scan + O(k) balanced-brace matches for each found inner type.
+ Stripping is O(m) per extracted snippet and is skipped when markers are absent.
+
+ Bugfix:
+ - Avoids false positives from identifiers like `record` used as a variable name
+ (e.g. `Object record = ...`) by only treating `class|interface|enum|record`
+ as declaration keywords at a declaration-start position (after `{`, `}`, `;`, start,
+ or after modifiers/annotations), and by cancelling pending candidates on `=` / `;`
+ (and on `(` for non-record kinds).
+ - Fixes slice start for inner types so we never “pull in” the previous statement line
+ (e.g. enum constants ending with `;`) when computing the snippet start.
+ """
+ s = src
+ n = len(s)
+ if n == 0:
+ return []
+
+ # Quick reject
+ if ("class" not in s and "interface" not in s and "enum" not in s and "record" not in s):
+ return []
+
+ # ---------- local helpers ----------
+ def _is_ident_start(ch: str) -> bool:
+ return ch.isalpha() or ch in "_$"
+
+ def _is_ident_part(ch: str) -> bool:
+ return ch.isalnum() or ch in "_$"
+
+ def _match_balanced_brace(lbrace: int, lim: int) -> int:
+ NORMAL, LINE, BLOCK, STRING, CHAR, TEXT = 0, 1, 2, 3, 4, 5
+ st = NORMAL
+ i = lbrace + 1
+ depth = 1
- def _skip_string(s, i):
- i += 1
- n = len(s)
- while i < n:
+ while i < lim:
c = s[i]
- if c == '\\':
- i += 2
- elif c == '"':
- return i + 1
- else:
- i += 1
- return n
+ if st == NORMAL:
+ if c == "/" and i + 1 < lim:
+ n2 = s[i + 1]
+ if n2 == "/":
+ st = LINE
+ i += 2
+ continue
+ if n2 == "*":
+ st = BLOCK
+ i += 2
+ continue
+ if c == '"':
+ if i + 2 < lim and s[i:i + 3] == '"""':
+ st = TEXT
+ i += 3
+ continue
+ st = STRING
+ i += 1
+ continue
+ if c == "'":
+ st = CHAR
+ i += 1
+ continue
- def _skip_char(s, i):
- i += 1
- n = len(s)
- while i < n:
- c = s[i]
- if c == '\\':
- i += 2
- elif c == "'":
- return i + 1
- else:
- i += 1
- return n
+ if c == "{":
+ depth += 1
+ elif c == "}":
+ depth -= 1
+ if depth == 0:
+ return i
+
+ elif st == LINE:
+ if c == "\n":
+ st = NORMAL
+ elif st == BLOCK:
+ if c == "*" and i + 1 < lim and s[i + 1] == "/":
+ st = NORMAL
+ i += 2
+ continue
+ elif st == STRING:
+ if c == "\\":
+ i += 2
+ continue
+ if c == '"':
+ st = NORMAL
+ elif st == CHAR:
+ if c == "\\":
+ i += 2
+ continue
+ if c == "'":
+ st = NORMAL
+ else: # TEXT
+ if i + 2 < lim and s[i:i + 3] == '"""':
+ st = NORMAL
+ i += 3
+ continue
- def _skip_ws_comments(s, i):
- """Skip whitespace and //... or /*...*/ comments. Always makes progress or returns."""
- n = len(s)
- while i < n:
- start = i
- # whitespace
- while i < n and s[i].isspace():
- i += 1
- # // line comment
- if i + 1 < n and s[i] == '/' and s[i + 1] == '/':
- i += 2
- while i < n and s[i] != '\n':
- i += 1
- continue
- # /* block comment */
- if i + 1 < n and s[i] == '/' and s[i + 1] == '*':
- i += 2
- end = s.find('*/', i)
- if end == -1:
- return n
- i = end + 2
- continue
- if i == start:
- break
- return i
+ i += 1
- def _read_ident(s, i):
- j, n = i, len(s)
- while j < n and _is_ident_char(s[j]):
- j += 1
- return s[i:j], j
+ return -1
- def _word_at(s, i, word):
- n = len(s)
- end = i + len(word)
- if end > n or s[i:end] != word:
+ def _line_is_annotation_only(line_start: int, line_end: int) -> bool:
+ """
+ True if the line contains only annotations (possibly multiple) plus whitespace/comments.
+ """
+ i = line_start
+ while i < line_end and s[i].isspace():
+ i += 1
+ if i >= line_end or s[i] != "@":
return False
- pre = s[i - 1] if i > 0 else ''
- post = s[end] if end < n else ''
- return (not _is_ident_char(pre)) and (not _is_ident_char(post))
-
- def _skip_parenthesized(s, i):
- n = len(s)
- if i >= n or s[i] != '(':
- return i
- depth = 1
- i += 1
- while i < n and depth > 0:
- i = _skip_ws_comments(s, i)
- if i >= n:
+
+ # Parse one-or-more annotations: @Name or @pkg.Name or @Name(...)
+ while True:
+ if i >= line_end or s[i] != "@":
break
- c = s[i]
- if c == '"':
- i = _skip_string(s, i); continue
- if c == "'":
- i = _skip_char(s, i); continue
- if c == '(':
- depth += 1; i += 1; continue
- if c == ')':
- depth -= 1; i += 1; continue
i += 1
- return i
+ if i >= line_end or not _is_ident_start(s[i]):
+ return False
+ i += 1
+ while i < line_end and (_is_ident_part(s[i]) or s[i] == "."):
+ i += 1
- def _skip_leading_annotations(s, i):
- """Skip @Anno, @pkg.Anno, @Anno(...), stacked."""
- n = len(s)
- while True:
i = _skip_ws_comments(s, i)
- if i >= n or s[i] != '@':
- return i
+ if i < line_end and s[i] == "(":
+ r = _match_balanced_parens(s, i)
+ if r == -1:
+ return False
+ i = r + 1
+
+ i = _skip_ws_comments(s, i)
+ if i < line_end and s[i] == "@":
+ continue
+ break
+
+ i = _skip_ws_comments(s, i)
+ while i < line_end and s[i].isspace():
i += 1
+ return i >= line_end
+
+ def _consume_annotations(i: int) -> int:
+ """
+ Consume one-or-more annotations starting at s[i] == '@' and return the first index
+ after them (skipping whitespace/comments between annotations).
+ """
+ while i < n and s[i] == "@":
+ i += 1
+ if i >= n or not _is_ident_start(s[i]):
+ return i
+
# qualified name
- while i < n and (_is_ident_char(s[i]) or s[i] == '.'):
+ i += 1
+ while i < n and (_is_ident_part(s[i]) or s[i] == "."):
i += 1
- i = _skip_ws_comments(s, i)
- if i < n and s[i] == '(':
- i = _skip_parenthesized(s, i)
- # loop to allow many annotations
- def _match_brace(s, i_open):
- n = len(s)
- if i_open >= n or s[i_open] != '{':
- return None
- depth, i = 1, i_open + 1
- steps, limit = 0, 10_000_000
- while i < n and depth > 0:
- steps += 1
- if steps > limit:
- return None
i = _skip_ws_comments(s, i)
- if i >= n:
- break
- c = s[i]
- if c == '"':
- i = _skip_string(s, i); continue
- if c == "'":
- i = _skip_char(s, i); continue
- if c == '{':
- depth += 1; i += 1; continue
- if c == '}':
- depth -= 1
- if depth == 0:
+ if i < n and s[i] == "(":
+ r = _match_balanced_parens(s, i)
+ if r == -1:
return i
- i += 1; continue
- i += 1
- return None
+ i = r + 1
- def _find_left_boundary(s, brace_pos):
+ i = _skip_ws_comments(s, i)
+ return i
+
+ def _sig_start_for_type(kind_pos: int) -> int:
"""
- Find the start position of the type declaration (keyword position).
- Backward bounded search.
+ Start at the beginning of the line containing the kind keyword, and (optionally)
+ include immediately preceding annotation-only lines. Critically, NEVER walk into
+ the previous statement line (e.g., enum constants ending with ';').
"""
- window_size = 4000
- base = max(0, brace_pos - window_size)
- window = s[base:brace_pos]
- best = -1
- for kw in ('class', 'interface', 'enum', 'record'):
- j = window.rfind(kw)
- while j != -1:
- pre = window[j - 1] if j > 0 else ''
- post = window[j + len(kw)] if j + len(kw) < len(window) else ''
- if not _is_ident_char(pre) and not _is_ident_char(post):
- best = max(best, j)
- break
- j = window.rfind(kw, 0, j)
- return base + best if best != -1 else max(0, brace_pos - 1)
-
- # -------------------------------
- # Collect top-level import block
- # -------------------------------
- def _collect_top_level_imports(s):
- imports = []
- NORMAL, LINE, BLOCK, STRING, CHAR = 0, 1, 2, 3, 4
- state = NORMAL
- depth = 0
- i, n = 0, len(s)
+ # Anchor to the declaration line that actually contains `class|interface|enum|record`.
+ sig_line_start = s.rfind("\n", 0, kind_pos) + 1
+ sig_start = sig_line_start
+
+ # Now look upward for annotation-only lines directly above.
+ k = sig_line_start
+ # Trim back over whitespace/newlines to the end of the previous non-empty content.
+ while k > 0 and s[k - 1].isspace():
+ k -= 1
- def _word_at_local(pos, word):
- end = pos + len(word)
- if end > n or s[pos:end] != word:
- return False
- pre = s[pos - 1] if pos > 0 else ''
- post = s[end] if end < n else ''
- return (not _is_ident_char(pre)) and (not _is_ident_char(post))
+ while k > 0:
+ line_start = s.rfind("\n", 0, k - 1) + 1
+ if _line_is_annotation_only(line_start, k):
+ sig_start = line_start
+ k = line_start
+ while k > 0 and s[k - 1].isspace():
+ k -= 1
+ continue
+ break
- saw_top_level_type = False
+ return sig_start
- while i < n and not saw_top_level_type:
- ch = s[i]
- if state == NORMAL:
- if ch == '/':
- if i + 1 < n and s[i + 1] == '/':
- state = LINE; i += 2; continue
- if i + 1 < n and s[i + 1] == '*':
- state = BLOCK; i += 2; continue
- if ch == '"': state = STRING; i += 1; continue
- if ch == "'": state = CHAR; i += 1; continue
-
- if ch == '{':
- depth += 1
- if depth == 1:
- saw_top_level_type = True
- i += 1
- continue
- if ch == '}':
- depth = max(0, depth - 1); i += 1; continue
+ # ---------- scan ----------
+ NORMAL, LINE, BLOCK, STRING, CHAR, TEXT = 0, 1, 2, 3, 4, 5
+ state = NORMAL
- if depth == 0 and ch == 'i' and _word_at_local(i, 'import'):
- j = i + 6
- while j < n and s[j] != ';':
- j += 1
- if j < n and s[j] == ';':
- imports.append(s[i:j + 1].strip())
- i = j + 1
- continue
+ inner_classes: list[str] = []
+
+ class_kinds = {"class", "interface", "enum", "record"}
+ class_mods = {
+ "public", "protected", "private",
+ "abstract", "final", "static", "sealed", "strictfp",
+ }
+ brace_depth = 0
+ named_type_depth_stack: list[int] = []
+
+ pending_kind: str | None = None
+ pending_kind_pos: int = -1
+ pending_name: str | None = None
+
+ # Only True at a plausible declaration start position.
+ can_start_decl = True
+
+ i = 0
+ while i < n:
+ c = s[i]
+
+ if state == NORMAL:
+ # comments
+ if c == "/" and i + 1 < n:
+ n2 = s[i + 1]
+ if n2 == "/":
+ state = LINE
+ i += 2
+ continue
+ if n2 == "*":
+ state = BLOCK
+ i += 2
+ continue
+
+ # string/text/char
+ if c == '"':
+ if i + 2 < n and s[i:i + 3] == '"""':
+ state = TEXT
+ i += 3
+ continue
+ state = STRING
+ i += 1
+ continue
+ if c == "'":
+ state = CHAR
i += 1
continue
- elif state == LINE:
- if ch == '\n': state = NORMAL
- i += 1; continue
+ # ---- cancel/advance on delimiters that cannot belong to a type header ----
+ if c == ";":
+ pending_kind = None
+ pending_kind_pos = -1
+ pending_name = None
+ can_start_decl = True
+ i += 1
+ continue
- elif state == BLOCK:
- if ch == '*' and i + 1 < n and s[i + 1] == '/':
- state = NORMAL; i += 2
- else:
- i += 1
+ if c == "=":
+ pending_kind = None
+ pending_kind_pos = -1
+ pending_name = None
+ can_start_decl = False
+ i += 1
continue
- elif state == STRING:
- if ch == '\\': i += 2
- else:
- if ch == '"': state = NORMAL
- i += 1
+ if c == "(" and pending_kind is not None and pending_kind != "record":
+ # `class|interface|enum` headers never contain '('
+ pending_kind = None
+ pending_kind_pos = -1
+ pending_name = None
+ can_start_decl = False
+ i += 1
continue
- else: # CHAR
- if ch == '\\': i += 2
- else:
- if ch == "'": state = NORMAL
- i += 1
+ # annotations at decl-start: consume and remain in decl-start mode
+ if c == "@" and can_start_decl and pending_kind is None:
+ # Special-case: allow "@interface Foo { ... }" as an inner "interface" decl.
+ if (i + 10 <= n and s[i + 1:i + 10] == "interface" and
+ (i + 10 == n or not _is_ident_part(s[i + 10]))):
+ i += 1 # next loop sees "interface"
+ continue
+
+ i = _consume_annotations(i)
+ can_start_decl = True
continue
- return ("\n".join(imports) + ("\n" if imports else ""))
+ # identifier tokenizing
+ if _is_ident_start(c):
+ j = i + 1
+ while j < n and _is_ident_part(s[j]):
+ j += 1
+ tok = s[i:j]
+
+ # After kind keyword, accept very next identifier as the type name.
+ if pending_kind is not None and pending_name is None:
+ pending_name = tok
+ can_start_decl = False
+ i = j
+ continue
- # -----------------------------------
- # Scan: find all named type regions
- # -----------------------------------
- def _find_all_named_types(s):
- kinds = ('class', 'interface', 'enum', 'record')
- regions = [] # tuples: (name, kind, start, end)
- i, n = 0, len(s)
- steps, limit = 0, 10_000_000
+ # Only recognize kind keywords at declaration starts.
+ if pending_kind is None and can_start_decl:
+ if tok in class_mods:
+ can_start_decl = True
+ elif tok in class_kinds:
+ pending_kind = tok
+ pending_kind_pos = i
+ pending_name = None
+ can_start_decl = False
+ else:
+ can_start_decl = False
- while i < n:
- steps += 1
- if steps > limit:
- break
+ i = j
+ continue
- i = _skip_ws_comments(s, i)
- if i >= n:
- break
+ # Otherwise we're in code/header; not a new decl start.
+ can_start_decl = False
+ i = j
+ continue
- c = s[i]
- if c == '"':
- i = _skip_string(s, i); continue
- if c == "'":
- i = _skip_char(s, i); continue
+ # brace open
+ if c == "{":
+ brace_depth += 1
+ can_start_decl = True
- i = _skip_leading_annotations(s, i)
- if i >= n:
- break
+ if pending_kind is not None and pending_name is not None:
+ is_inner = bool(named_type_depth_stack)
- matched_kind = None
- for kw in kinds:
- if _word_at(s, i, kw):
- matched_kind = kw
- break
- if not matched_kind:
- i += 1
- continue
+ # FIXED: anchor to the declaration line; do not walk into previous statements.
+ sig_start = _sig_start_for_type(pending_kind_pos)
- # consume keyword
- i += len(matched_kind)
- i = _skip_ws_comments(s, i)
+ body_close = _match_balanced_brace(i, n)
+ if body_close != -1 and is_inner:
+ raw = s[sig_start:body_close + 1]
- # read type name
- name, i2 = _read_ident(s, i)
- if not name:
- i += 1
- continue
- i = i2
+ cleaned = raw
+ if "/" in cleaned and (("//" in cleaned) or ("/*" in cleaned)):
+ cleaned = _strip_comments_keep_strings_range(cleaned, 0, len(cleaned))
+ if "@" in cleaned:
+ cleaned = _strip_annotations_in_range(cleaned, 0, len(cleaned))
- # consume header up to '{'
- while i < n:
- i = _skip_ws_comments(s, i)
- if i >= n or s[i] == '{':
- break
- if s[i] == '"':
- i = _skip_string(s, i); continue
- if s[i] == "'":
- i = _skip_char(s, i); continue
- if _is_ident_char(s[i]):
- _, i = _read_ident(s, i); continue
- if s[i] == '<':
- depth = 1; i += 1
- while i < n and depth > 0:
- i = _skip_ws_comments(s, i)
- if i >= n: break
- ch = s[i]
- if ch == '"': i = _skip_string(s, i); continue
- if ch == "'": i = _skip_char(s, i); continue
- if ch == '<': depth += 1; i += 1; continue
- if ch == '>': depth -= 1; i += 1; continue
- i += 1
+ cleaned = cleaned.strip()
+ if cleaned:
+ inner_classes.append(cleaned)
+
+ named_type_depth_stack.append(brace_depth)
+
+ pending_kind = None
+ pending_kind_pos = -1
+ pending_name = None
+
+ i += 1
continue
- if s[i] == '(':
- i = _skip_parenthesized(s, i); continue
- i += 1
- if i >= n or s[i] != '{':
+ pending_kind = None
+ pending_kind_pos = -1
+ pending_name = None
+
+ i += 1
continue
- body_open = i
- body_close = _match_brace(s, body_open)
- if body_close is None:
- break
+ # brace close
+ if c == "}":
+ if brace_depth > 0:
+ if named_type_depth_stack and named_type_depth_stack[-1] == brace_depth:
+ named_type_depth_stack.pop()
+ brace_depth -= 1
- regions.append((name, matched_kind, body_open, body_close))
- i = body_close + 1
+ pending_kind = None
+ pending_kind_pos = -1
+ pending_name = None
- return regions
+ can_start_decl = True
+ i += 1
+ continue
- # -----------------------------------
- # Assemble results for inner classes
- # -----------------------------------
- imports_block = _collect_top_level_imports(src)
- imports_prefix = imports_block if (not imports_block or imports_block.endswith("\n")) else imports_block + "\n"
+ # any other non-ws char likely means we're in an expression/statement
+ if not c.isspace():
+ can_start_decl = False
- regions = _find_all_named_types(src) # (name, kind, start, end)
- regions.sort(key=lambda t: (t[2], -(t[3] - t[2]))) # by start, longer first
+ i += 1
+ continue
- results = []
- stack = [] # elements: (end, sig_start)
+ # ---------- non-NORMAL states ----------
+ if state == LINE:
+ if c == "\n":
+ state = NORMAL
+ i += 1
+ continue
- for name, kind, start, end in regions:
- # pop completed outers
- while stack and start >= stack[-1][0]:
- stack.pop()
+ if state == BLOCK:
+ if c == "*" and i + 1 < n and s[i + 1] == "/":
+ state = NORMAL
+ i += 2
+ continue
+ i += 1
+ continue
- # compute header/signature start
- hdr = _find_left_boundary(src, start)
- hdr = _skip_ws_comments(src, hdr)
- sig_start = _skip_leading_annotations(src, hdr)
- sig_start = _skip_ws_comments(src, sig_start)
+ if state == STRING:
+ if c == "\\":
+ i += 2
+ continue
+ if c == '"':
+ state = NORMAL
+ i += 1
+ continue
- is_inner = bool(stack)
- if is_inner and end >= sig_start:
- type_src = src[sig_start:end + 1]
- results.append(imports_prefix + type_src)
+ if state == CHAR:
+ if c == "\\":
+ i += 2
+ continue
+ if c == "'":
+ state = NORMAL
+ i += 1
+ continue
- stack.append((end, sig_start))
+ # TEXT
+ if i + 2 < n and s[i:i + 3] == '"""':
+ state = NORMAL
+ i += 3
+ continue
+ i += 1
- return results
\ No newline at end of file
+ return inner_classes
\ No newline at end of file
diff --git a/src/exploit_iq_commons/utils/java_utils.py b/src/exploit_iq_commons/utils/java_utils.py
index 47f6c33a..3bcbd91d 100644
--- a/src/exploit_iq_commons/utils/java_utils.py
+++ b/src/exploit_iq_commons/utils/java_utils.py
@@ -58,6 +58,19 @@
re.X,
)
+# Matches real Java type headers like "class Foo", not ".class"
+_JAVA_TYPE_HEADER_RE = re.compile(r'(? str:
- """
- Return the simple name of the first class/interface/enum/record declared
- anywhere in the given Java source (top-level or nested). Makes no assumptions
- about leading comments, javadocs, annotations, package/imports, etc.
-
- - Skips // line comments, /* ... */ (incl. Javadoc) block comments, strings "..."
- and character literals '...'.
- - Handles 'record' headers (e.g., `record Point(int x, int y) { ... }`).
- - Matches annotation type declarations too (e.g., `public @interface Foo {}`),
- because the token `interface` is still present.
- - Returns the *simple* identifier (e.g., 'Point').
- """
- n = len(src)
- i = 0
-
- # States
- NORMAL, LINE, BLOCK, STRING, CHAR = 0, 1, 2, 3, 4
- state = NORMAL
-
- KW = ("class", "interface", "enum", "record")
-
- def is_ws(ch: str) -> bool:
- return ch <= " " and ch in " \t\r\n\f\v"
+# @lru_cache(maxsize=150000)
+# def get_type_name(src: str) -> str:
+# """
+# Return the simple name of the first class/interface/enum/record declared
+# anywhere in the given Java source (top-level or nested). Makes no assumptions
+# about leading comments, javadocs, annotations, package/imports, etc.
+#
+# - Skips // line comments, /* ... */ (incl. Javadoc) block comments, strings "..."
+# and character literals '...'.
+# - Handles 'record' headers (e.g., `record Point(int x, int y) { ... }`).
+# - Matches annotation type declarations too (e.g., `public @interface Foo {}`),
+# because the token `interface` is still present.
+# - Returns the *simple* identifier (e.g., 'Point').
+# """
+# n = len(src)
+# i = 0
+#
+# # States
+# NORMAL, LINE, BLOCK, STRING, CHAR = 0, 1, 2, 3, 4
+# state = NORMAL
+#
+# KW = ("class", "interface", "enum", "record")
+#
+# def is_ws(ch: str) -> bool:
+# return ch <= " " and ch in " \t\r\n\f\v"
+#
+# def is_id_start(ch: str) -> bool:
+# # Java allows '_' and '$'; also allow unicode letters
+# return (
+# ch == "_" or ch == "$" or
+# ("A" <= ch <= "Z") or ("a" <= ch <= "z") or
+# (ch >= "\u0080" and ch.isalpha())
+# )
+#
+# def is_id_part(ch: str) -> bool:
+# return is_id_start(ch) or ("0" <= ch <= "9")
+#
+# def word_at(pos: int, word: str) -> bool:
+# """Return True if 'word' starts at src[pos] with Java-style word boundaries."""
+# end = pos + len(word)
+# if end > n or src[pos:end] != word:
+# return False
+# # left boundary: start of file or not an identifier part
+# if pos > 0 and (src[pos - 1].isalnum() or src[pos - 1] in "_$"):
+# return False
+# # right boundary: end of file or not an identifier part
+# if end < n and (src[end].isalnum() or src[end] in "_$"):
+# return False
+# return True
+#
+# while i < n:
+# ch = src[i]
+#
+# if state == NORMAL:
+# # comment/string/char entry
+# if ch == "/":
+# if i + 1 < n and src[i + 1] == "/":
+# state = LINE
+# i += 2
+# continue
+# if i + 1 < n and src[i + 1] == "*":
+# state = BLOCK
+# i += 2
+# continue
+# elif ch == '"':
+# state = STRING
+# i += 1
+# continue
+# elif ch == "'":
+# state = CHAR
+# i += 1
+# continue
+#
+# # Try to match a type keyword at this position
+# for kw in KW:
+# if word_at(i, kw):
+# j = i + len(kw)
+# # skip whitespace between keyword and identifier
+# while j < n and is_ws(src[j]):
+# j += 1
+# # Next must be a valid Java identifier start (type name)
+# if j < n and is_id_start(src[j]):
+# k = j + 1
+# while k < n and is_id_part(src[k]):
+# k += 1
+# return src[j:k]
+# # If not a valid identifier here, keep scanning
+# i += 1
+#
+# elif state == LINE:
+# if ch == "\n":
+# state = NORMAL
+# i += 1
+#
+# elif state == BLOCK:
+# if ch == "*" and i + 1 < n and src[i + 1] == "/":
+# state = NORMAL
+# i += 2
+# else:
+# i += 1
+#
+# elif state == STRING:
+# if ch == "\\" and i + 1 < n:
+# i += 2 # skip escaped
+# continue
+# if ch == '"':
+# state = NORMAL
+# i += 1
+#
+# elif state == CHAR:
+# if ch == "\\" and i + 1 < n:
+# i += 2
+# continue
+# if ch == "'":
+# state = NORMAL
+# i += 1
+#
+# return ""
- def is_id_start(ch: str) -> bool:
- # Java allows '_' and '$'; also allow unicode letters
- return (
- ch == "_" or ch == "$" or
- ("A" <= ch <= "Z") or ("a" <= ch <= "z") or
- (ch >= "\u0080" and ch.isalpha())
- )
-
- def is_id_part(ch: str) -> bool:
- return is_id_start(ch) or ("0" <= ch <= "9")
-
- def word_at(pos: int, word: str) -> bool:
- """Return True if 'word' starts at src[pos] with Java-style word boundaries."""
- end = pos + len(word)
- if end > n or src[pos:end] != word:
- return False
- # left boundary: start of file or not an identifier part
- if pos > 0 and (src[pos - 1].isalnum() or src[pos - 1] in "_$"):
- return False
- # right boundary: end of file or not an identifier part
- if end < n and (src[end].isalnum() or src[end] in "_$"):
- return False
- return True
-
- while i < n:
- ch = src[i]
-
- if state == NORMAL:
- # comment/string/char entry
- if ch == "/":
- if i + 1 < n and src[i + 1] == "/":
- state = LINE
- i += 2
- continue
- if i + 1 < n and src[i + 1] == "*":
- state = BLOCK
- i += 2
- continue
- elif ch == '"':
- state = STRING
- i += 1
- continue
- elif ch == "'":
- state = CHAR
- i += 1
- continue
-
- # Try to match a type keyword at this position
- for kw in KW:
- if word_at(i, kw):
- j = i + len(kw)
- # skip whitespace between keyword and identifier
- while j < n and is_ws(src[j]):
- j += 1
- # Next must be a valid Java identifier start (type name)
- if j < n and is_id_start(src[j]):
- k = j + 1
- while k < n and is_id_part(src[k]):
- k += 1
- return src[j:k]
- # If not a valid identifier here, keep scanning
- i += 1
+@lru_cache(maxsize=2500000)
+def is_java_method(source: str) -> bool:
+ s = source
- elif state == LINE:
- if ch == "\n":
- state = NORMAL
- i += 1
+ # avoid scanning twice for '(' and '->'
+ has_paren = "(" in s
+ has_arrow = "->" in s
+ if not has_paren and not has_arrow:
+ return False
- elif state == BLOCK:
- if ch == "*" and i + 1 < n and src[i + 1] == "/":
- state = NORMAL
- i += 2
- else:
- i += 1
+ n = len(s)
+ find_in_source = s.find
+ startswith = s.startswith
+ ws = JAVA_METHOD_WS_CHARS
- elif state == STRING:
- if ch == "\\" and i + 1 < n:
- i += 2 # skip escaped
+ # --- FAST REJECT #1: skip leading whitespace/comments and reject FreeMarker templates ---
+ # This avoids scanning huge .ftl-like files (your ArrowType.java template input).
+ j = 0
+ while j < n:
+ ch = s[j]
+ if ch in ws:
+ j += 1
+ continue
+ if ch == "/" and j + 1 < n:
+ c2 = s[j + 1]
+ if c2 == "/":
+ nl = find_in_source("\n", j + 2)
+ j = n if nl == -1 else nl + 1
continue
- if ch == '"':
- state = NORMAL
- i += 1
-
- elif state == CHAR:
- if ch == "\\" and i + 1 < n:
- i += 2
+ if c2 == "*":
+ end = find_in_source("*/", j + 2)
+ j = n if end == -1 else end + 2
continue
- if ch == "'":
- state = NORMAL
- i += 1
+ break
- return ""
+ if j < n and (s.startswith("<#", j) or s.startswith("<@", j)):
+ return False
-@lru_cache(maxsize=2500000)
-def is_java_method(source: str) -> bool:
- """
- Return True iff the given Java source snippet is a *method or constructor*
- declaration, or (fallback) contains a lambda ('->') outside of comments/strings.
-
- Notes:
- - Methods with explicit return type: detected.
- - Constructors (no return type) are treated as methods too.
- - Control-flow constructs are excluded.
- - A lambda arrow outside comments/strings/text-blocks returns True.
- """
- s = source
- # Safe early reject: no method/ctor is possible without '('; no lambda possible without '->'.
- if "(" not in s and "->" not in s:
+ # --- FAST REJECT #2: compilation unit headers ("package"/"import") ---
+ # Safe for your prior examples (method/ctor/lambda snippets) and avoids full-file scans.
+ leading_ident_match = _LEADING_IDENT_RE.match(s, j)
+ if leading_ident_match and leading_ident_match.group(0) in ("package", "import"):
return False
- n = len(s)
+ # Conservative “might be a top-level type decl somewhere” flag:
+ # Use a real "type header" detector, not substring "class" (to avoid `.class` false positives).
+ maybe_type_decl = bool(has_arrow and _JAVA_TYPE_HEADER_RE.search(s))
+
i = 0
NORMAL, LINE, BLOCK, STRING, CHAR, TEXT = 0, 1, 2, 3, 4, 5
@@ -864,10 +907,6 @@ def is_java_method(source: str) -> bool:
brace = 0
saw_lambda = False
- find_in_source = s.find
- startswith = s.startswith
- ws = JAVA_METHOD_WS_CHARS
-
def is_id_start(ch: str) -> bool:
return ch == "_" or ("A" <= ch <= "Z") or ("a" <= ch <= "z") or (ch >= "\u0080" and ch.isalpha())
@@ -878,13 +917,9 @@ def is_id_part(ch: str) -> bool:
(ch >= "\u0080" and ch.isalnum())
)
- def _skip_quoted(j: int, quote: str) -> int:
- """
- Skip to the first unescaped closing quote starting at j (j is after opening quote).
- Returns index right after the closing quote, or n if none.
- """
- while j < n:
- end = find_in_source(quote, j)
+ def _skip_quoted(jq: int, quote: str) -> int:
+ while jq < n:
+ end = find_in_source(quote, jq)
if end == -1:
return n
k = end - 1
@@ -894,92 +929,92 @@ def _skip_quoted(j: int, quote: str) -> int:
k -= 1
if back % 2 == 0:
return end + 1
- j = end + 1
+ jq = end + 1
return n
- def skip_paren(j: int) -> int:
+ def skip_paren(jp: int) -> int:
depth = 1
st = NORMAL
- while j < n and depth > 0:
- ch = s[j]
+ while jp < n and depth > 0:
+ chp = s[jp]
if st == NORMAL:
- if ch == "/":
- if j + 1 < n:
- c2 = s[j + 1]
+ if chp == "/":
+ if jp + 1 < n:
+ c2 = s[jp + 1]
if c2 == "/":
st = LINE
- j += 2
+ jp += 2
continue
if c2 == "*":
st = BLOCK
- j += 2
+ jp += 2
continue
- elif ch == '"':
- if j + 2 < n and s[j+1] == '"' and s[j+2] == '"':
+ elif chp == '"':
+ if jp + 2 < n and s[jp + 1] == '"' and s[jp + 2] == '"':
st = TEXT
- j += 3
+ jp += 3
continue
st = STRING
- j += 1
+ jp += 1
continue
- elif ch == "'":
+ elif chp == "'":
st = CHAR
- j += 1
+ jp += 1
continue
- elif ch == "(":
+ elif chp == "(":
depth += 1
- j += 1
+ jp += 1
continue
- elif ch == ")":
+ elif chp == ")":
depth -= 1
- j += 1
+ jp += 1
continue
else:
- j += 1
+ jp += 1
continue
elif st == LINE:
- nl = find_in_source("\n", j)
- j = n if nl == -1 else nl + 1
+ nl = find_in_source("\n", jp)
+ jp = n if nl == -1 else nl + 1
st = NORMAL
elif st == BLOCK:
- end = find_in_source("*/", j)
- j = n if end == -1 else end + 2
+ end = find_in_source("*/", jp)
+ jp = n if end == -1 else end + 2
st = NORMAL
elif st == STRING:
- j = _skip_quoted(j, '"')
+ jp = _skip_quoted(jp, '"')
st = NORMAL
elif st == CHAR:
- j = _skip_quoted(j, "'")
+ jp = _skip_quoted(jp, "'")
st = NORMAL
else: # TEXT
while True:
- end = find_in_source('"""', j)
+ end = find_in_source('"""', jp)
if end == -1:
- j = n
+ jp = n
break
k = end - 1
back = 0
while k >= 0 and s[k] == "\\":
back += 1
k -= 1
- j = end + 3
+ jp = end + 3
if back % 2 == 0:
break
st = NORMAL
- return j
+ return jp
- def read_ident(j: int) -> tuple[Optional[str], int]:
- if j >= n or not is_id_start(s[j]):
- return None, j
- k = j + 1
+ def read_ident(jr: int) -> tuple[Optional[str], int]:
+ if jr >= n or not is_id_start(s[jr]):
+ return None, jr
+ k = jr + 1
while k < n and is_id_part(s[k]):
k += 1
- return s[j:k], k
+ return s[jr:k], k
- def read_qual_ident(j: int) -> tuple[Optional[str], int]:
- name, k = read_ident(j)
+ def read_qual_ident(jr: int) -> tuple[Optional[str], int]:
+ name, k = read_ident(jr)
if not name:
- return None, j
+ return None, jr
while k < n and s[k] == ".":
nm2, k2 = read_ident(k + 1)
if not nm2:
@@ -987,139 +1022,132 @@ def read_qual_ident(j: int) -> tuple[Optional[str], int]:
k = k2
return name, k
- def skip_ws_comments(j: int) -> int:
- while j < n:
- ch = s[j]
- if ch in ws:
- j += 1
+ def skip_ws_comments(jw: int) -> int:
+ while jw < n:
+ chw = s[jw]
+ if chw in ws:
+ jw += 1
continue
- if ch == "/" and j + 1 < n:
- c2 = s[j + 1]
+ if chw == "/" and jw + 1 < n:
+ c2 = s[jw + 1]
if c2 == "/":
- nl = find_in_source("\n", j + 2)
- j = n if nl == -1 else nl + 1
+ nl = find_in_source("\n", jw + 2)
+ jw = n if nl == -1 else nl + 1
continue
if c2 == "*":
- end = find_in_source("*/", j + 2)
- j = n if end == -1 else end + 2
+ end = find_in_source("*/", jw + 2)
+ jw = n if end == -1 else end + 2
continue
break
- return j
+ return jw
- def has_dot_between(j: int, k: int) -> bool:
- j = skip_ws_comments(j)
- while j < k:
- ch = s[j]
- if ch == ".":
+ def has_dot_between(jd: int, kd: int) -> bool:
+ jd = skip_ws_comments(jd)
+ while jd < kd:
+ chd = s[jd]
+ if chd == ".":
return True
- if ch in ws:
- j += 1
+ if chd in ws:
+ jd += 1
continue
- if ch == "/" and j + 1 < n:
- c2 = s[j + 1]
+ if chd == "/" and jd + 1 < n:
+ c2 = s[jd + 1]
if c2 == "/":
- nl = find_in_source("\n", j + 2)
- j = n if nl == -1 else nl + 1
+ nl = find_in_source("\n", jd + 2)
+ jd = n if nl == -1 else nl + 1
continue
if c2 == "*":
- end = find_in_source("*/", j + 2)
- j = n if end == -1 else end + 2
+ end = find_in_source("*/", jd + 2)
+ jd = n if end == -1 else end + 2
continue
- j += 1
+ jd += 1
return False
def plausible_method_decl(name_start: int) -> bool:
- """
- Must see a real return type token before the name (method),
- OR see only annotations/modifiers/generics/comments before the name (constructor).
- """
- j = name_start
- j0 = max(s.rfind("\n", 0, j), s.rfind("{", 0, j), s.rfind(";", 0, j)) + 1
- j = skip_ws_comments(j0)
+ jx = name_start
+ j0 = max(s.rfind("\n", 0, jx), s.rfind("{", 0, jx), s.rfind(";", 0, jx)) + 1
+ jx = skip_ws_comments(j0)
broken = False
- while j < name_start:
- if s[j] == JAVA_ANNOTATION_SYMBOL:
- _, j = read_qual_ident(j + 1)
- j = skip_ws_comments(j)
- if j < n and s[j] == "(":
- j = skip_paren(j + 1)
- j = skip_ws_comments(j)
+ while jx < name_start:
+ if s[jx] == JAVA_ANNOTATION_SYMBOL:
+ _, jx = read_qual_ident(jx + 1)
+ jx = skip_ws_comments(jx)
+ if jx < n and s[jx] == "(":
+ jx = skip_paren(jx + 1)
+ jx = skip_ws_comments(jx)
continue
- if s[j] == "<":
+ if s[jx] == "<":
depth = 1
- j += 1
- while j < n and depth > 0:
- if s[j] == "<":
+ jx += 1
+ while jx < n and depth > 0:
+ if s[jx] == "<":
depth += 1
- elif s[j] == ">":
+ elif s[jx] == ">":
depth -= 1
- j += 1
- j = skip_ws_comments(j)
+ jx += 1
+ jx = skip_ws_comments(jx)
continue
- tok, k = read_ident(j)
+ tok, k = read_ident(jx)
if tok:
if tok in JAVA_METHOD_PRIM_TYPES:
return True
-
if tok in JAVA_METHOD_METH_MODS:
- j = skip_ws_comments(k)
+ jx = skip_ws_comments(k)
continue
-
if tok == "non":
k2 = skip_ws_comments(k)
if k2 < n and s[k2] == "-":
k3 = skip_ws_comments(k2 + 1)
id2, j2 = read_ident(k3)
if id2 == "sealed":
- j = skip_ws_comments(j2)
+ jx = skip_ws_comments(j2)
continue
- # reference return type start
- j = k
- j = skip_ws_comments(j)
+ jx = k
+ jx = skip_ws_comments(jx)
- if j < n and s[j] == "<":
+ if jx < n and s[jx] == "<":
depth = 1
- j += 1
- while j < n and depth > 0:
- if s[j] == "<":
+ jx += 1
+ while jx < n and depth > 0:
+ if s[jx] == "<":
depth += 1
- elif s[j] == ">":
+ elif s[jx] == ">":
depth -= 1
- j += 1
- j = skip_ws_comments(j)
+ jx += 1
+ jx = skip_ws_comments(jx)
- while j + 1 < n and s[j] == "[" and s[j + 1] == "]":
- j += 2
- j = skip_ws_comments(j)
+ while jx + 1 < n and s[jx] == "[" and s[jx + 1] == "]":
+ jx += 2
+ jx = skip_ws_comments(jx)
tcount = 0
while tcount < 4:
- j = skip_ws_comments(j)
- if j < n and s[j] == "&":
- j = skip_ws_comments(j + 1)
- nm2, k2 = read_ident(j)
+ jx = skip_ws_comments(jx)
+ if jx < n and s[jx] == "&":
+ jx = skip_ws_comments(jx + 1)
+ nm2, k2 = read_ident(jx)
if not nm2:
break
- j = k2
- j = skip_ws_comments(j)
- if j < n and s[j] == "<":
+ jx = k2
+ jx = skip_ws_comments(jx)
+ if jx < n and s[jx] == "<":
depth = 1
- j += 1
- while j < n and depth > 0:
- if s[j] == "<":
+ jx += 1
+ while jx < n and depth > 0:
+ if s[jx] == "<":
depth += 1
- elif s[j] == ">":
+ elif s[jx] == ">":
depth -= 1
- j += 1
- j = skip_ws_comments(j)
- while j + 1 < n and s[j] == "[" and s[j + 1] == "]":
- j += 2
- j = skip_ws_comments(j)
+ jx += 1
+ jx = skip_ws_comments(jx)
+ while jx + 1 < n and s[jx] == "[" and s[jx + 1] == "]":
+ jx += 2
+ jx = skip_ws_comments(jx)
tcount += 1
else:
break
@@ -1129,13 +1157,10 @@ def plausible_method_decl(name_start: int) -> bool:
broken = True
break
- # No return type found, but nothing invalid encountered -> constructor
- return not broken
+ return not broken # constructor
while i < n:
if state == NORMAL:
- # Big win: when inside braces, skip boring characters in C (regex search),
- # only landing on characters that can change state or matter for lambda/braces.
if brace > 0:
m = _BRACE_NORMAL_SPECIAL_RE.search(s, i)
if not m:
@@ -1143,9 +1168,10 @@ def plausible_method_decl(name_start: int) -> bool:
i = m.start()
ch = s[i]
- # keep ordering consistent with original: lambda before comment/string checks
if ch == "-" and i + 1 < n and s[i + 1] == ">":
saw_lambda = True
+ if not maybe_type_decl:
+ return True
i += 2
continue
@@ -1164,7 +1190,7 @@ def plausible_method_decl(name_start: int) -> bool:
continue
if ch == '"':
- if i + 2 < n and s[i+1] == '"' and s[i+2] == '"':
+ if i + 2 < n and s[i + 1] == '"' and s[i + 2] == '"':
state = TEXT
i += 3
continue
@@ -1188,15 +1214,15 @@ def plausible_method_decl(name_start: int) -> bool:
i += 1
continue
- # non-handled special (should not happen)
i += 1
continue
- # brace == 0: full logic (header-level)
ch = s[i]
if ch == "-" and i + 1 < n and s[i + 1] == ">":
saw_lambda = True
+ if not maybe_type_decl:
+ return True
i += 2
continue
@@ -1212,7 +1238,7 @@ def plausible_method_decl(name_start: int) -> bool:
i += 2
continue
elif ch == '"':
- if i + 2 < n and s[i+1] == '"' and s[i+2] == '"':
+ if i + 2 < n and s[i + 1] == '"' and s[i + 2] == '"':
state = TEXT
i += 3
continue
@@ -1232,12 +1258,10 @@ def plausible_method_decl(name_start: int) -> bool:
i += 1
continue
- # brace==0 path
if ch in ws:
i += 1
continue
- # skip annotations at top level
if ch == JAVA_ANNOTATION_SYMBOL:
_, i = read_qual_ident(i + 1)
i = skip_ws_comments(i)
@@ -1245,43 +1269,46 @@ def plausible_method_decl(name_start: int) -> bool:
i = skip_paren(i + 1)
continue
- # skip modifiers & detect type headers early
if is_id_start(ch):
- ident, j = read_ident(i)
+ ident, j2 = read_ident(i)
+
if ident in JAVA_METHOD_CLASS_MODS:
- i = skip_ws_comments(j)
+ i = skip_ws_comments(j2)
continue
+
if ident == "non":
- k = skip_ws_comments(j)
+ k = skip_ws_comments(j2)
if k < n and s[k] == "-":
k2 = skip_ws_comments(k + 1)
- id2, j2 = read_ident(k2)
+ id2, j3 = read_ident(k2)
if id2 == "sealed":
- i = skip_ws_comments(j2)
+ i = skip_ws_comments(j3)
continue
+
if ident in JAVA_METHOD_CLASS_KINDS:
- i2 = skip_ws_comments(j)
- name, j2 = read_ident(i2)
+ i2 = skip_ws_comments(j2)
+ name, j4 = read_ident(i2)
if name:
return False
- i = j
+ i = j2
continue
- # potential method/ctor header
+ i = j2
+ continue
+
if ch == "(":
- j = i - 1
- while j >= 0 and s[j] in ws:
- j -= 1
- end = j + 1
- while j >= 0 and is_id_part(s[j]):
- j -= 1
- name = s[j+1:end] if end > (j + 1) else None
+ jx = i - 1
+ while jx >= 0 and s[jx] in ws:
+ jx -= 1
+ end = jx + 1
+ while jx >= 0 and is_id_part(s[jx]):
+ jx -= 1
+ name = s[jx + 1:end] if end > (jx + 1) else None
if name and (name not in JAVA_METHOD_CTRL_WORDS) and not has_dot_between(end, i):
k = skip_paren(i + 1)
q = skip_ws_comments(k)
- # optional throws
if startswith("throws", q) and (q + 6 == n or not is_id_part(s[q + 6])):
q = skip_ws_comments(q + 6)
while q < n and s[q] not in "{;":
@@ -1309,7 +1336,7 @@ def plausible_method_decl(name_start: int) -> bool:
q = skip_ws_comments(q)
if q < n and s[q] in "{;":
- if plausible_method_decl(j + 1):
+ if plausible_method_decl(jx + 1):
return True
i += 1
@@ -1378,9 +1405,9 @@ def is_java_type(source: str) -> bool:
brace = 0
ws = " \t\r\n\f\v"
- class_kinds = {"class", "interface", "enum", "record"}
- class_mods = {"public","protected","private","abstract","final","static","sealed","strictfp"}
- meth_ctrl = {"if","for","while","switch","catch","synchronized","return","new","case"}
+ class_kinds = JAVA_METHOD_CLASS_KINDS
+ class_mods = JAVA_METHOD_CLASS_MODS
+ meth_ctrl = JAVA_METHOD_CTRL_WORDS
def is_ws(ch): return ch in ws
def is_id_start(ch): return ch == '_' or ch == '$' or ('A' <= ch <= 'Z') or ('a' <= ch <= 'z') or (ch >= '\u0080' and ch.isalpha())
@@ -1395,28 +1422,41 @@ def read_ident(j: int):
def skip_ws_comments(j: int) -> int:
while j < n:
ch = s[j]
- if is_ws(ch): j += 1; continue
+ if is_ws(ch):
+ j += 1
+ continue
if ch == '/' and j + 1 < n:
c2 = s[j + 1]
if c2 == '/':
- j = (s.find('\n', j + 2) + 1) or n; continue
+ nl = s.find('\n', j + 2)
+ j = n if nl == -1 else nl + 1
+ continue
if c2 == '*':
end = s.find("*/", j + 2)
- j = n if end == -1 else end + 2; continue
+ j = n if end == -1 else end + 2
+ continue
break
return j
def skip_paren_balanced(j: int) -> int:
- depth = 1; st = NORMAL
+ depth = 1
+ st = NORMAL
while j < n and depth > 0:
ch = s[j]
if st == NORMAL:
if ch == '/':
+ # FIX: if it's not // or /*, it's an operator; advance!
if j + 1 < n and s[j+1] == '/':
- j = (s.find('\n', j + 2) + 1) or n; continue
+ nl = s.find('\n', j + 2)
+ j = n if nl == -1 else nl + 1
+ continue
if j + 1 < n and s[j+1] == '*':
end = s.find('*/', j + 2)
- j = n if end == -1 else end + 2; continue
+ j = n if end == -1 else end + 2
+ continue
+ j += 1
+ continue
+
elif ch == '"':
if j + 2 < n and s[j+1] == '"' and s[j+2] == '"':
st = TEXT; j += 3; continue
@@ -1429,18 +1469,21 @@ def skip_paren_balanced(j: int) -> int:
depth -= 1; j += 1; continue
else:
j += 1; continue
+
elif st == STRING:
while j < n:
c2 = s[j]; j += 1
if c2 == '\\': j += 1
elif c2 == '"': break
st = NORMAL
+
elif st == CHAR:
while j < n:
c2 = s[j]; j += 1
if c2 == '\\': j += 1
elif c2 == "'": break
st = NORMAL
+
else: # TEXT
while True:
end = s.find('"""', j)
@@ -1450,6 +1493,7 @@ def skip_paren_balanced(j: int) -> int:
j = end + 3
if back % 2 == 0: break
st = NORMAL
+
return j
def skip_angles_balanced(j: int) -> int:
@@ -1460,10 +1504,13 @@ def skip_angles_balanced(j: int) -> int:
elif ch == '>': depth -= 1
elif ch == '/':
if j + 1 < n and s[j+1] == '/':
- j = (s.find('\n', j + 2) + 1) or n; continue
+ nl = s.find('\n', j + 2)
+ j = n if nl == -1 else nl + 1
+ continue
if j + 1 < n and s[j+1] == '*':
end = s.find('*/', j + 2)
- j = n if end == -1 else end + 2; continue
+ j = n if end == -1 else end + 2
+ continue
j += 1
return j
@@ -1476,10 +1523,13 @@ def has_dot_between(a: int, b: int) -> bool:
j += 1; continue
if ch == '/' and j + 1 < n:
if s[j+1] == '/':
- j = (s.find('\n', j + 2) + 1) or n; continue
+ nl = s.find('\n', j + 2)
+ j = n if nl == -1 else nl + 1
+ continue
if s[j+1] == '*':
end = s.find('*/', j + 2)
- j = n if end == -1 else end + 2; continue
+ j = n if end == -1 else end + 2
+ continue
j += 1
return False
@@ -1515,12 +1565,10 @@ def skip_one_type(p: int) -> int:
# NEW: skip comma-separated type list after extends/implements/permits
def skip_type_list(p: int) -> int:
p = skip_ws_comments(p)
- tcount = 0
while p < n:
p2 = skip_one_type(p)
if p2 == p: break
p = skip_ws_comments(p2)
- tcount += 1
if p < n and s[p] == ',':
p = skip_ws_comments(p + 1)
continue
@@ -1687,8 +1735,8 @@ def create_inheritance_map(java_types: List[Document]) -> dict[Tuple[str, str],
"""
Build a jar-aware inheritance/implements map.
- Key: (simple_name, source_path)
- Value: ordered, deduped list of (simple_name, source_path) including:
+ Key: (fqcn, source_path)
+ Value: ordered, deduped list of (fqcn, source_path) including:
- self
- ancestors (class chain or interface parents)
- implemented interfaces (direct, via ancestors, and via super-interfaces)
@@ -1701,7 +1749,7 @@ def create_inheritance_map(java_types: List[Document]) -> dict[Tuple[str, str],
External/unseen types (e.g., JDK classes/interfaces) are INCLUDED in the
upward/interface lists as placeholders with the current type's source path:
- (simple_name, current_source_of_key)
+ (fqcn_or_token, current_source_of_key)
Placeholders are NOT used to wire reverse edges (subclasses/implementers).
Parsing guarantees:
@@ -1717,6 +1765,60 @@ def create_inheritance_map(java_types: List[Document]) -> dict[Tuple[str, str],
"""
# ---------------- helpers ----------------
+
+ def _balance_body_end(s: str, open_brace_idx: int) -> int:
+ """Return index just after the matching '}' for body that starts at open_brace_idx."""
+ depth = 1
+ i = open_brace_idx + 1
+ n = len(s)
+ while i < n and depth:
+ ch = s[i]
+ if ch == '{':
+ depth += 1
+ elif ch == '}':
+ depth -= 1
+ i += 1
+ return i
+
+ def _find_all_type_decls(src: str) -> List[Dict]:
+ """
+ Return ordered list of type decl dicts with fields:
+ kind, name, header (text up to '{'), hstart, bstart, bend, enclosing (names list).
+ Uses a single pass nesting stack to compute 'enclosing' chain.
+ """
+ s = strip_comments_preserving_newlines(src)
+ decls: List[Dict] = []
+
+ # First pass: collect raw declarations with their body ranges
+ for m in TYPE_DECL_RE.finditer(s):
+ kind, name = m.group(1), m.group(2)
+ brace = s.find("{", m.end())
+ if brace == -1:
+ continue # malformed; ignore
+ bend = _balance_body_end(s, brace)
+ header_text = s[m.start():brace]
+ decls.append({
+ "kind": kind,
+ "name": name,
+ "header": re.sub(r"\s+", " ", header_text).strip(),
+ "hstart": m.start(),
+ "bstart": brace,
+ "bend": bend,
+ })
+
+ if not decls:
+ return []
+
+ # Second pass: assign enclosing chain using a simple interval stack
+ decls.sort(key=lambda d: d["hstart"])
+ stack: List[Dict] = []
+ for d in decls:
+ while stack and d["hstart"] >= stack[-1]["bend"]:
+ stack.pop()
+ d["enclosing"] = [x["name"] for x in stack]
+ stack.append(d)
+
+ return decls
def strip_comments_preserving_newlines(s: str) -> str:
"""Remove /* ... */ and // ... comments, preserving newlines for positions."""
def _blk(m):
@@ -1817,7 +1919,7 @@ def tup_dedup_keep_order(pairs: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
def meta_source(doc) -> str:
try:
meta = getattr(doc, "metadata", None) or {}
- return meta.get("source") or meta.get("path") or ""
+ return meta.get("source") or ""
except Exception:
return ""
@@ -1960,75 +2062,98 @@ def source_jar_id(path: str) -> str:
for doc in tqdm(java_types, total=len(java_types), desc="Building types inheritance map..."):
src_all = doc.page_content or ""
- # Prefer 'Code for:' header if present
- idx = src_all.lower().find("code for:")
- if idx != -1:
- line_end = src_all.find('\n', idx)
- if line_end == -1:
- line_end = len(src_all)
- header_line = src_all[idx:line_end]
- header_text_raw = re.sub(r".*?Code\s*for\s*:\s*", "", header_line, flags=re.I).strip()
- else:
- header_text_raw = find_header_slice_no_comments(src_all) or ""
-
- if not header_text_raw:
- header_text_raw = drop_leading_annotations(before_brace(src_all))
-
- header_text = before_brace(header_text_raw)
- header_text = re.sub(r"\s+", " ", header_text).strip()
-
- m_name = name_kind_re.search(header_text)
-
+ # Parse package/imports once per compilation unit
pkg = (pkg_re.search(src_all) or [None, ""])[1]
imps = import_re.findall(src_all)
src_path = meta_source(doc)
+ file_simple = simple_name_from_filename(doc) # e.g., "FormEncodedDataDefinition"
+ s_nc = strip_comments_preserving_newlines(src_all)
- if not m_name:
- # fallback to filename
- file_simple = simple_name_from_filename(doc)
+ # Discover ALL type declarations (top-level + nested)
+ decls = _find_all_type_decls(src_all)
+
+ # Fallback: no decls found → keep legacy behavior (top-level by filename)
+ if not decls:
if not file_simple:
continue
- fqcn = f"{pkg}.{file_simple}" if pkg else file_simple
kind = "class"
extends_names: List[str] = []
implements_names: List[str] = []
- parsed_name = file_simple
- else:
- kind, parsed_name = m_name.group(1), m_name.group(2)
+ declared_simple = file_simple
+ fqcn = f"{pkg}.{declared_simple}" if pkg else declared_simple
+
+ rec = {
+ "fqcn": fqcn,
+ "simple": declared_simple,
+ "pkg": pkg,
+ "imports": imps,
+ "kind": kind,
+ "extends_raw": extends_names,
+ "implements_raw": implements_names,
+ "source": src_path,
+ "jar": source_jar_id(src_path),
+ }
+ variants_by_fq.setdefault(fqcn, []).append(rec)
+ known_fqcns.add(fqcn)
+ continue
+
+ # If we're processing a *slice* that only contains an inner type, we may not see the outer.
+ # Detect presence of the outer by name in this source (to avoid mislabeling top-level siblings).
+ has_outer_decl = bool(file_simple and re.search(
+ rf"\b(class|interface|enum|record)\b\s+{re.escape(file_simple)}\b", s_nc))
+
+ # Build records for every discovered declaration
+ for d in decls:
+ kind = d["kind"]
+ declared_simple = d["name"]
+ header_text = d["header"]
+
+ # Extract extends/implements from the specific header
ext_seg, impl_seg = extract_clauses_after_kindname(header_text)
- # LAST RESORT (classes only): scan de-commented full source for "class ... extends ..."
- if kind == "class" and (not ext_seg) and ("extends" in src_all):
- s_nc = strip_comments_preserving_newlines(src_all)
+ # LAST RESORT (classes only): scan for "class ... extends ..." in de-commented source slice
+ if kind == "class" and (not ext_seg) and ("extends" in s_nc):
m_ext2 = re.search(
- rf"\bclass\s+{re.escape(parsed_name)}\b[^{{;]*?\bextends\b\s+([^{{;]+)",
+ rf"\bclass\s+{re.escape(declared_simple)}\b[^{{;]*?\bextends\b\s+([^{{;]+)",
s_nc
)
if m_ext2:
ext_seg = m_ext2.group(1).strip()
+
extends_names = split_type_list_strict(ext_seg) if ext_seg else []
implements_names = split_type_list_strict(impl_seg) if impl_seg else []
- # Defensive cleanup to avoid generic bounds leaking as parents/interfaces
- extends_names = [clean_type_token(x) for x in extends_names if clean_type_token(x)]
- implements_names = [clean_type_token(x) for x in implements_names if clean_type_token(x)]
-
- file_simple = simple_name_from_filename(doc)
- simple_name = file_simple or parsed_name
- fqcn = f"{pkg}.{simple_name}" if pkg else simple_name
-
- rec = {
- "fqcn": fqcn,
- "simple": simple_name,
- "pkg": pkg,
- "imports": imps,
- "kind": kind,
- "extends_raw": extends_names,
- "implements_raw": implements_names,
- "source": src_path,
- "jar": source_jar_id(src_path),
- }
- variants_by_fq.setdefault(fqcn, []).append(rec)
- known_fqcns.add(fqcn)
+ # Defensive cleanup
+ extends_names = [clean_type_token(x) for x in extends_names if clean_type_token(x)]
+ implements_names = [clean_type_token(x) for x in implements_names if clean_type_token(x)]
+
+ # Compute FQCN with correct nesting
+ if d["enclosing"]:
+ # Full chain from discovered parents: pkg.Outer.Inner.Current
+ prefix = ".".join(d["enclosing"])
+ fqcn = f"{pkg}.{prefix}.{declared_simple}" if pkg else f"{prefix}.{declared_simple}"
+ else:
+ # No enclosing type discovered.
+ # If this looks like an inner slice (no outer decl visible in this source)
+ # and filename suggests the real outer, prefix with file_simple.
+ if file_simple and declared_simple != file_simple and not has_outer_decl:
+ outer_fq = f"{pkg}.{file_simple}" if pkg else file_simple
+ fqcn = f"{outer_fq}.{declared_simple}"
+ else:
+ fqcn = f"{pkg}.{declared_simple}" if pkg else declared_simple
+
+ rec = {
+ "fqcn": fqcn,
+ "simple": declared_simple,
+ "pkg": pkg,
+ "imports": imps,
+ "kind": kind,
+ "extends_raw": extends_names,
+ "implements_raw": implements_names,
+ "source": src_path,
+ "jar": source_jar_id(src_path),
+ }
+ variants_by_fq.setdefault(fqcn, []).append(rec)
+ known_fqcns.add(fqcn)
# simple -> known FQCNs
by_simple: Dict[str, List[str]] = {}
@@ -2038,7 +2163,13 @@ def source_jar_id(path: str) -> str:
def make_resolver(pkg: str, imports: List[str]):
"""
Build a simple type resolver for a compilation unit.
- See top-level docstring for resolution precedence.
+ Resolution precedence:
+ 1) explicit imports
+ 2) same package
+ 3) star-import packages (unique)
+ 4) globally unique simple name
+ Also handles partially-qualified nested types like
+ 'FormParserFactory.ParserDefinition' → 'io.undertow.server.handlers.form.FormParserFactory.ParserDefinition'.
"""
explicit: Dict[str, str] = {}
star_import_packages: List[str] = []
@@ -2048,31 +2179,46 @@ def make_resolver(pkg: str, imports: List[str]):
else:
explicit[simple_name_from_fqcn(imp)] = imp # simple -> explicit FQCN
- def resolve(token: str) -> str:
- tok = base_token(token)
- if not tok:
- return "" # empty token
- if '.' in tok:
- return tok
- if tok in explicit:
- return explicit[tok]
+ def _resolve_head(head_simple: str) -> str | None:
+ """Resolve a top-level simple name to FQCN using the precedence above."""
+ if head_simple in explicit:
+ return explicit[head_simple]
if pkg:
- candidate = f"{pkg}.{tok}"
+ candidate = f"{pkg}.{head_simple}"
if candidate in known_fqcns:
return candidate
# Star imports – accept only a unique hit
candidates: List[str] = []
- for fqcn in by_simple.get(tok, []):
- if any(fqcn.startswith(base + ".") for base in star_import_packages):
- candidates.append(fqcn)
+ for fq_top in by_simple.get(head_simple, []):
+ if any(fq_top.startswith(base + ".") for base in star_import_packages):
+ candidates.append(fq_top)
if len(candidates) == 1:
return candidates[0]
# Globally unique simple name across all known types
- all_candidates = by_simple.get(tok, [])
+ all_candidates = by_simple.get(head_simple, [])
if len(all_candidates) == 1:
return all_candidates[0]
- # Unknown or ambiguous → keep simple name
- return tok
+ return None
+
+ def resolve(token: str) -> str:
+ tok = base_token(token).strip()
+ if not tok:
+ return "" # empty token
+
+ # Handle partially-qualified nested types, e.g., "FormParserFactory.ParserDefinition"
+ if "." in tok:
+ head, rest = tok.split(".", 1)
+ # If already looks fully-qualified for a known top-level type in this package,
+ # prefer composing with the package (avoids leaving it partially-qualified).
+ fq_head = _resolve_head(head)
+ if fq_head:
+ return fq_head + "." + rest
+ # If we can't resolve the head, leave as-is (may already be a fully-qualified package path)
+ return tok
+
+ # Simple (non-dotted) type name
+ fq_simple = _resolve_head(tok)
+ return fq_simple if fq_simple else tok
return resolve
@@ -2143,7 +2289,7 @@ def extends_chain_pairs(start_vid: Tuple[str, str]) -> List[Tuple[str, str]]:
"""
Upward chain for classes (linear) and interface parents (DFS),
jar-aware; placeholders where parents are unseen.
- Returns list of (simple, source).
+ Returns list of (fqcn, source).
"""
if start_vid in extends_chain_pairs_memo:
return extends_chain_pairs_memo[start_vid]
@@ -2172,9 +2318,10 @@ def dfs_iface(v: Tuple[str, str]):
seen_vids.add(parent_vid)
if is_interface_record(parent_rec):
dfs_iface(parent_vid)
- chain.append((parent_rec["simple"], parent_rec["source"]))
+ chain.append((parent_rec["fqcn"], parent_rec["source"]))
else:
- chain.append((simple_name_from_fqcn(parent_fqcn), cur["source"]))
+ # Keep the best-available token (may be FQCN or simple if unresolved)
+ chain.append((parent_fqcn, cur["source"]))
dfs_iface(start_vid)
out = tup_dedup_keep_order(chain)
@@ -2198,13 +2345,13 @@ def dfs_iface(v: Tuple[str, str]):
if parent_vid in visited_vids or parent_vid == cur_vid:
break
visited_vids.add(parent_vid)
- chain.append((parent_rec["simple"], parent_rec["source"]))
+ chain.append((parent_rec["fqcn"], parent_rec["source"]))
if not is_interface_record(parent_rec):
cur_vid = parent_vid
else:
break
else:
- chain.append((simple_name_from_fqcn(parent_fqcn), r["source"]))
+ chain.append((parent_fqcn, r["source"]))
break
out = tup_dedup_keep_order(chain)
@@ -2295,7 +2442,7 @@ def dfs_iface_vid(v: Tuple[str, str]):
def iface_super_pairs(iface_vid: Tuple[str, str]) -> List[Tuple[str, str]]:
"""
For an interface variant vid, return all transitive super-interfaces
- as (simple, source) pairs, jar-aware. Memoized.
+ as (fqcn, source) pairs, jar-aware. Memoized.
"""
if iface_vid in iface_super_pairs_memo:
return iface_super_pairs_memo[iface_vid]
@@ -2329,10 +2476,10 @@ def dfs_iface_vid(v: Tuple[str, str]):
if parent_rec:
pk = (parent_rec["fqcn"], parent_rec["source"])
dfs_iface_vid(pk)
- add_pair((parent_rec["simple"], parent_rec["source"]))
+ add_pair((parent_rec["fqcn"], parent_rec["source"]))
else:
# unknown external interface -> placeholder bound to current interface's source
- add_pair((simple_name_from_fqcn(parent_fqcn), cur_src))
+ add_pair((parent_fqcn, cur_src))
visiting.remove(v)
dfs_iface_vid(iface_vid)
@@ -2349,6 +2496,7 @@ def all_interfaces_pairs(start_vid: Tuple[str, str]) -> List[Tuple[str, str]]:
+ interfaces from ancestor classes (and their super-interfaces).
For an interface: its super-interfaces.
Jar-aware; placeholders only for interfaces not present in provided set.
+ Returns (fqcn, source) pairs.
"""
if start_vid in all_ifaces_pairs_memo:
return all_ifaces_pairs_memo[start_vid]
@@ -2374,12 +2522,12 @@ def add_iface_and_supers(iface_fqcn: str):
iface_rec = pick_variant_for_jar(iface_fqcn, jar_id)
if iface_rec:
ik = (iface_rec["fqcn"], iface_rec["source"])
- add_pair((iface_rec["simple"], iface_rec["source"]))
+ add_pair((iface_rec["fqcn"], iface_rec["source"]))
for sup in iface_super_pairs(ik):
add_pair(sup)
else:
# external/unseen -> placeholder (cannot know its supers)
- add_pair((simple_name_from_fqcn(iface_fqcn), cur_src))
+ add_pair((iface_fqcn, cur_src))
# Direct implements on this type
for iface_fqcn in rec.get("impls", []):
@@ -2467,7 +2615,7 @@ def all_implementers_of_interface_variant(iface_vid: Tuple[str, str]) -> List[Tu
out: Dict[Tuple[str, str], List[Tuple[str, str]]] = {}
for vid, rec in rec_by_vid.items():
- self_name = rec["simple"]
+ self_fqcn = rec["fqcn"]
source = rec.get("source", "")
up_chain_pairs = extends_chain_pairs(vid) # may include placeholders
@@ -2477,16 +2625,16 @@ def all_implementers_of_interface_variant(iface_vid: Tuple[str, str]) -> List[Tu
if is_interface_record(rec):
impl_vids = all_implementers_of_interface_variant(vid) # vids only
down_pairs.extend(
- (rec_by_vid[v]["simple"], rec_by_vid[v]["source"]) for v in impl_vids
+ (rec_by_vid[v]["fqcn"], rec_by_vid[v]["source"]) for v in impl_vids
)
else:
subs_vids = all_subclasses_variant(vid)
down_pairs.extend(
- (rec_by_vid[v]["simple"], rec_by_vid[v]["source"]) for v in subs_vids
+ (rec_by_vid[v]["fqcn"], rec_by_vid[v]["source"]) for v in subs_vids
)
- vals = tup_dedup_keep_order([(self_name, source)] + up_chain_pairs + iface_pairs + down_pairs)
- out[(self_name, source)] = vals
+ vals = tup_dedup_keep_order([(self_fqcn, source)] + up_chain_pairs + iface_pairs + down_pairs)
+ out[(self_fqcn, source)] = vals
return out
@@ -2754,4 +2902,42 @@ def strip_java_generics(type_str: str) -> str:
res = _PUNCT_SPACE_RE.sub(r'\1', res) # "Map []" -> "Map[]", "Map . Entry" -> "Map.Entry"
res = _VARARGS_SPACE_RE.sub(r'\1', res) # "List ..." -> "List..."
- return res
\ No newline at end of file
+ return res
+
+def extract_fqcn(text: str) -> str:
+ """
+ Convert a source file path to an FQCN.
+ Supported forms:
+      1) dependencies-sources/<artifact>-<version>-sources/<package/path>/<ClassName>.java
+      2) .../src/main/java/<package/path>/<ClassName>.java
+      3) <package/path>/<ClassName>.java (no prefix)
+
+ Examples:
+ 'dependencies-sources/hibernate-core-6.6.13.Final-sources/org/hibernate/type/descriptor/java/ArrayJavaType.java'
+ -> 'org.hibernate.type.descriptor.java.ArrayJavaType'
+ 'org/hibernate/type/descriptor/java/ArrayJavaType.java'
+ -> 'org.hibernate.type.descriptor.java.ArrayJavaType'
+ '.../src/main/java/org/keycloak/.../VerifyEmailActionTokenHandler.java'
+ -> 'org.keycloak....VerifyEmailActionTokenHandler'
+ """
+ p = (text or "").replace("\\", "/").strip()
+
+ # Prefer the standard Maven/Gradle source-root marker when present.
+ src_marker = "/src/main/java/"
+ if src_marker in p:
+ tail = p.split(src_marker, 1)[1]
+ else:
+ # Otherwise, if there's a '-sources/' prefix, drop everything up to and including
+ # the last occurrence (handles dependencies-sources/...-sources/...).
+ marker = "-sources/"
+ cut = p.rfind(marker)
+ tail = p[cut + len(marker):] if cut != -1 else p
+
+ if tail.startswith("/"):
+ tail = tail[1:]
+
+ # Drop the .java suffix (case-sensitive per your examples)
+ if tail.endswith(".java"):
+ tail = tail[:-5]
+
+    return tail.replace("/", ".")
diff --git a/src/vuln_analysis/configs/config-http-nim.yml b/src/vuln_analysis/configs/config-http-nim.yml
index c613e686..d6321ae6 100644
--- a/src/vuln_analysis/configs/config-http-nim.yml
+++ b/src/vuln_analysis/configs/config-http-nim.yml
@@ -56,6 +56,8 @@ functions:
enable_functions_usage_search: true
Function Locator:
_type: package_and_function_locator
+ Function Library Version Finder:
+ _type: calling_function_library_version_finder
Code Semantic Search:
_type: local_vdb_retriever
embedder_name: nim_embedder
@@ -87,6 +89,7 @@ functions:
- Call Chain Analyzer
- Function Caller Finder
- Function Locator
+ - Function Library Version Finder
max_concurrency: null
max_iterations: 10
prompt_examples: false
diff --git a/src/vuln_analysis/configs/config-http-openai.yml b/src/vuln_analysis/configs/config-http-openai.yml
index 4e6ca79f..a67e18c1 100644
--- a/src/vuln_analysis/configs/config-http-openai.yml
+++ b/src/vuln_analysis/configs/config-http-openai.yml
@@ -63,6 +63,8 @@ functions:
enable_functions_usage_search: true
Function Locator:
_type: package_and_function_locator
+ Function Library Version Finder:
+ _type: calling_function_library_version_finder
Code Semantic Search:
_type: local_vdb_retriever
embedder_name: nim_embedder
@@ -94,6 +96,7 @@ functions:
- Call Chain Analyzer
- Function Caller Finder
- Function Locator
+ - Function Library Version Finder
max_concurrency: null
max_iterations: 10
prompt_examples: false
diff --git a/src/vuln_analysis/configs/config-tracing.yml b/src/vuln_analysis/configs/config-tracing.yml
index 331289e9..397258e6 100644
--- a/src/vuln_analysis/configs/config-tracing.yml
+++ b/src/vuln_analysis/configs/config-tracing.yml
@@ -67,6 +67,8 @@ functions:
enable_functions_usage_search: true
Function Locator:
_type: package_and_function_locator
+ Function Library Version Finder:
+ _type: calling_function_library_version_finder
Code Semantic Search:
_type: local_vdb_retriever
embedder_name: nim_embedder
@@ -98,6 +100,7 @@ functions:
- Call Chain Analyzer
- Function Caller Finder
- Function Locator
+ - Function Library Version Finder
max_concurrency: null
max_iterations: 10
prompt_examples: false
diff --git a/src/vuln_analysis/configs/config.yml b/src/vuln_analysis/configs/config.yml
index 7916989c..9484f646 100644
--- a/src/vuln_analysis/configs/config.yml
+++ b/src/vuln_analysis/configs/config.yml
@@ -60,6 +60,8 @@ functions:
max_retries: 5
Container Analysis Data:
_type: container_image_analysis_data
+ Function Library Version Finder:
+ _type: calling_function_library_version_finder
cve_agent_executor:
_type: cve_agent_executor
llm_name: cve_agent_executor_llm
@@ -68,6 +70,7 @@ functions:
- Docs Semantic Search
# - Code Keyword Search # Uncomment to enable keyword search
- CVE Web Search
+ - Function Library Version Finder
max_concurrency: null
max_iterations: 10
prompt_examples: false
diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py
index 4f1e168d..42345a74 100644
--- a/src/vuln_analysis/functions/cve_agent.py
+++ b/src/vuln_analysis/functions/cve_agent.py
@@ -14,6 +14,8 @@
# limitations under the License.
import asyncio
+
+from exploit_iq_commons.utils.dep_tree import Ecosystem
from vuln_analysis.runtime_context import ctx_state
import typing
from aiq.builder.builder import Builder
@@ -86,10 +88,17 @@ async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder,
(tool.name == ToolNames.FUNCTION_CALLER_FINDER and (not config.transitive_search_tool_enabled or
state.code_index_path is None)) or
(tool.name == ToolNames.FUNCTION_LOCATOR and (not config.transitive_search_tool_enabled or
+ state.code_index_path is None)) or
+ (tool.name == ToolNames.FUNCTION_LIBRARY_VERSION_FINDER and (not config.transitive_search_tool_enabled or
state.code_index_path is None))
)
]
-
+
+ tools = [
+ tool for tool in tools
+ if not ((tool.name == ToolNames.FUNCTION_CALLER_FINDER and state.original_input.input.image.ecosystem != Ecosystem.GO) or
+ (tool.name == ToolNames.FUNCTION_LIBRARY_VERSION_FINDER and state.original_input.input.image.ecosystem != Ecosystem.JAVA))
+ ]
# Get tool names after filtering for dynamic guidance
enabled_tool_names = [tool.name for tool in tools]
diff --git a/src/vuln_analysis/functions/cve_checklist.py b/src/vuln_analysis/functions/cve_checklist.py
index 696c7fb9..73b11fed 100644
--- a/src/vuln_analysis/functions/cve_checklist.py
+++ b/src/vuln_analysis/functions/cve_checklist.py
@@ -58,13 +58,14 @@ async def cve_checklist(config: CVEChecklistToolConfig, builder: Builder):
agent_config = builder.get_function_config(config.agent_name)
agent_tool_names = agent_config.tool_names if hasattr(agent_config, 'tool_names') else None
- async def generate_checklist_for_cve(cve_intel):
+ async def generate_checklist_for_cve(cve_intel, ecosystem=None):
checklist = await generate_checklist(prompt=config.prompt,
llm=llm,
input_dict=cve_intel,
tool_names=agent_tool_names,
- enable_llm_list_parsing=False)
+ enable_llm_list_parsing=False,
+ ecosystem=ecosystem)
checklist = await _parse_list([checklist])
@@ -75,7 +76,15 @@ async def _arun(state: AgentMorpheusEngineState) -> AgentMorpheusEngineState:
intel_df = data_utils.merge_intel_and_plugin_data_convert_to_dataframe(state.cve_intel)
workflow_cve_intel = intel_df.to_dict(orient='records')
- results = await asyncio.gather(*(generate_checklist_for_cve(cve_intel) for cve_intel in workflow_cve_intel))
+ # Extract ecosystem for ecosystem-aware example selection
+ ecosystem = None
+ if (state.original_input and state.original_input.input
+ and state.original_input.input.image
+ and state.original_input.input.image.ecosystem):
+ ecosystem = state.original_input.input.image.ecosystem.value
+
+ results = await asyncio.gather(*(generate_checklist_for_cve(cve_intel, ecosystem=ecosystem)
+ for cve_intel in workflow_cve_intel))
state.checklist_plans = dict(results)
return state
diff --git a/src/vuln_analysis/tools/serp.py b/src/vuln_analysis/tools/serp.py
index 7fe820b8..c208c46b 100644
--- a/src/vuln_analysis/tools/serp.py
+++ b/src/vuln_analysis/tools/serp.py
@@ -48,6 +48,9 @@ async def _arun(query: str) -> str:
_arun,
description=(
"Searches the web for information about CVEs, vulnerabilities, libraries, "
- "and security advisories not available in the container."
+ "and security advisories not available in the container. "
+ "WARNING: Results may contain information for different library versions. "
+ "Always verify version-specific claims against the actual installed version. "
+ "Include the library version in your search query when possible for more precise results."
)
- )
+ )
\ No newline at end of file
diff --git a/src/vuln_analysis/tools/tool_names.py b/src/vuln_analysis/tools/tool_names.py
index 248fec7f..bded2160 100644
--- a/src/vuln_analysis/tools/tool_names.py
+++ b/src/vuln_analysis/tools/tool_names.py
@@ -44,6 +44,9 @@ class ToolNames:
CONTAINER_ANALYSIS_DATA = "Container Analysis Data"
"""Retrieves pre-analyzed data from earlier container scan steps"""
+ FUNCTION_LIBRARY_VERSION_FINDER = "Function Library Version Finder"
+    """Finds the installed version(s) of a library/package in the application's dependency tree"""
+
# Export as module-level constants
CODE_SEMANTIC_SEARCH = ToolNames.CODE_SEMANTIC_SEARCH
@@ -54,6 +57,7 @@ class ToolNames:
FUNCTION_CALLER_FINDER = ToolNames.FUNCTION_CALLER_FINDER
CVE_WEB_SEARCH = ToolNames.CVE_WEB_SEARCH
CONTAINER_ANALYSIS_DATA = ToolNames.CONTAINER_ANALYSIS_DATA
+FUNCTION_LIBRARY_VERSION_FINDER = ToolNames.FUNCTION_LIBRARY_VERSION_FINDER
@@ -66,5 +70,6 @@ class ToolNames:
'FUNCTION_CALLER_FINDER',
'CVE_WEB_SEARCH',
'CONTAINER_ANALYSIS_DATA',
- 'FUNCTION_LOCATOR'
+ 'FUNCTION_LOCATOR',
+ 'FUNCTION_LIBRARY_VERSION_FINDER',
]
diff --git a/src/vuln_analysis/tools/transitive_code_search.py b/src/vuln_analysis/tools/transitive_code_search.py
index 6106b317..acaee311 100644
--- a/src/vuln_analysis/tools/transitive_code_search.py
+++ b/src/vuln_analysis/tools/transitive_code_search.py
@@ -42,9 +42,31 @@
FUNCTION_NAME_EXTRACTOR_TOOL_NAME = "calling_function_name_extractor"
+FUNCTION_LIBRARY_VERSION_FINDER_TOOL_NAME = "calling_function_library_version_finder"
+
TRANSITIVE_CODE_SEARCH_TOOL_NAME = "transitive_code_search"
logger = LoggingFactory.get_agent_logger(__name__)
+_QUERY_FORMAT_ERROR = (
+ "Invalid input format. The query must contain exactly one comma separating package and function. "
+ "Expected format 1: 'package_name,function_name' (e.g., 'urllib,parse'). "
+ "Expected format 2 (java): 'maven_gav,class_name.function_name' or 'maven_gav,fqcn.function_name' (preferred) "
+ "(e.g., 'commons-beanutils:commons-beanutils:x.y.z,a.b.c.ClassA.foo'). "
+ "Each tool call must contain exactly ONE query. Do NOT combine multiple queries with 'and'. "
+ "Please retry with the correct format."
+)
+
+
+def _validate_query_format(query: str) -> tuple[bool, str]:
+ """Validate that query matches expected format and return (is_valid, error_message_or_cleaned_query)."""
+ cleaned = query.strip().split("\n")[0].strip().strip("'\"\u2018\u2019\u201c\u201d").strip()
+ # Remove surrounding backticks if present
+ cleaned = cleaned.strip("`")
+ parts = cleaned.split(",")
+ if len(parts) != 2 or not parts[0].strip() or not parts[1].strip():
+ return False, f"{_QUERY_FORMAT_ERROR} Received: '{query}'"
+ return True, cleaned
+
class TransitiveCodeSearchToolConfig(FunctionBaseConfig, name=("%s" % TRANSITIVE_CODE_SEARCH_TOOL_NAME)):
"""
@@ -63,6 +85,11 @@ class PackageAndFunctionLocatorToolConfig(FunctionBaseConfig, name=("%s" % PACKA
Package and function locator tool used to validate package names and find function names using fuzzy matching.
"""
+class FunctionLibraryVersionFinderToolConfig(FunctionBaseConfig, name=FUNCTION_LIBRARY_VERSION_FINDER_TOOL_NAME):
+ """
+    Finds the installed version(s) of a library/package in the application's dependency tree.
+ """
+
def get_call_of_chains_retriever(documents_embedder, si, query: str):
documents: list[Document]
git_repo = None
@@ -108,15 +135,18 @@ def get_transitive_code_searcher(query: str):
@register_function(config_type=TransitiveCodeSearchToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def transitive_search(config: TransitiveCodeSearchToolConfig,
- builder: Builder): # pylint: disable=unused-argument
+ builder: Builder, verbose_mode: bool = False): # pylint: disable=unused-argument
"""
Call Chain Analyzer tool used to search source code function reachability.
"""
@catch_tool_errors(TRANSITIVE_CODE_SEARCH_TOOL_NAME)
async def _arun(query: str) -> tuple:
+ is_valid, validation_result = _validate_query_format(query)
+ if not is_valid:
+ return False, [validation_result]
transitive_code_searcher: TransitiveCodeSearcher
- transitive_code_searcher = get_transitive_code_searcher(query)
- result = transitive_code_searcher.search(query)
+ transitive_code_searcher = get_transitive_code_searcher(validation_result)
+ result = transitive_code_searcher.search(validation_result)
return result
yield FunctionInfo.from_fn(
@@ -128,9 +158,9 @@ async def _arun(query: str) -> tuple:
Input format 1: 'package_name,function_name'.
Example 1: 'urllib,parse'.
- Input format 2(java): 'maven_gav,class_name.function_name'.
- Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.
-
+ Input format 2(java): 'maven_gav,class_name.function_name' or 'maven_gav,fqcn.function_name' (preferred).
+ Example 2(java): 'commons-beanutils:commons-beanutils:x.y.z,a.b.c.ClassA.foo'.
+
Returns: (is_reachable: bool, call_hierarchy_path: list).
"""))
@@ -172,12 +202,15 @@ async def package_and_function_locator(config: PackageAndFunctionLocatorToolConf
@catch_tool_errors(PACKAGE_AND_FUNCTION_LOCATOR_TOOL_NAME)
async def _arun(query: str) -> dict:
+ is_valid, validation_result = _validate_query_format(query)
+ if not is_valid:
+ return {"error": validation_result}
coc_retriever: ChainOfCallsRetrieverBase
transitive_code_searcher: TransitiveCodeSearcher
- transitive_code_searcher = get_transitive_code_searcher(query)
+ transitive_code_searcher = get_transitive_code_searcher(validation_result)
coc_retriever = transitive_code_searcher.chain_of_calls_retriever
locator = FunctionNameLocator(coc_retriever)
- result = await locator.locate_functions(query)
+ result = await locator.locate_functions(validation_result)
pkg_msg = "Package is valid."
if not locator.is_package_valid and not locator.is_std_package:
pkg_msg = "Package is not valid."
@@ -197,8 +230,67 @@ async def _arun(query: str) -> dict:
Input format 1: 'package_name,function_name' or 'package_name,class_name.method_name'.
Example 1: 'libxml2,xmlParseDocument'.
- Input format 2(java): 'maven_gav,class_name.method_name'.
- Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.
+ Input format 2(java): 'maven_gav,class_name.method_name' or 'maven_gav,fqcn.method_name' (preferred).
+ Example 2(java): 'commons-beanutils:commons-beanutils:x.y.z,a.b.c.ClassA.foo'.
Returns: {'ecosystem': str, 'package_msg': str, 'result': [function_names]}.
"""))
+
+
+@register_function(config_type=FunctionLibraryVersionFinderToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
+async def library_version_finder(config: FunctionLibraryVersionFinderToolConfig,
+ builder: Builder): # pylint: disable=unused-argument
+ """
+ Function Library Version Finder tool used to check which version of a package/library
+ is installed in the application's dependency tree.
+ """
+
+ @catch_tool_errors(FUNCTION_LIBRARY_VERSION_FINDER_TOOL_NAME)
+ async def _arun(query: str) -> dict:
+ transitive_code_searcher = get_transitive_code_searcher(query)
+ coc_retriever = transitive_code_searcher.chain_of_calls_retriever
+
+ # Clean the query: strip whitespace, trailing junk after newlines, then quotes (including unicode smart quotes)
+ cleaned_query = query.strip().split("\n")[0].strip().strip("'\"\u2018\u2019\u201c\u201d").strip()
+ # Search for matching packages in the dependency tree
+ search_term = cleaned_query.lower()
+ matching_packages = []
+ for package in coc_retriever.supported_packages:
+ # Match against full GAV or individual segments
+ # Java GAV format: "groupId:artifactId:version"
+ package_lower = package.lower()
+ parts = package_lower.split(":")
+ if (search_term in package_lower or
+ any(search_term == part for part in parts) or
+ any(search_term in part for part in parts)):
+ matching_packages.append(package)
+
+ if not matching_packages:
+ return {
+ "ecosystem": coc_retriever.ecosystem.name,
+ "found": False,
+ "message": f"No package matching '{cleaned_query}' found in the dependency tree.",
+ "matching_packages": []
+ }
+
+ return {
+ "ecosystem": coc_retriever.ecosystem.name,
+ "found": True,
+ "message": f"Found {len(matching_packages)} matching package(s) in the dependency tree.",
+ "matching_packages": matching_packages
+ }
+
+ yield FunctionInfo.from_fn(
+ _arun,
+ description=("""
+ Finds the exact version of a library/package installed in the application's dependency tree.
+ Use this tool to determine the EXACT version of a library before interpreting version-specific
+ information from CVE Web Search results.
+
+ Input: library or package name to search for.
+ Example (java): 'commons-beanutils'
+ Example (python): 'requests'
+
+ Returns: {'ecosystem': str, 'found': bool, 'message': str, 'matching_packages': list}.
+ For Java, matching_packages contains Maven GAV coordinates like 'groupId:artifactId:version'.
+"""))
diff --git a/src/vuln_analysis/utils/checklist_prompt_generator.py b/src/vuln_analysis/utils/checklist_prompt_generator.py
index 3ed21b15..0226dde9 100644
--- a/src/vuln_analysis/utils/checklist_prompt_generator.py
+++ b/src/vuln_analysis/utils/checklist_prompt_generator.py
@@ -115,12 +115,18 @@ async def generate_checklist(prompt: str | None,
llm: BaseLanguageModel,
input_dict: dict,
tool_names: list[str] | None = None,
- enable_llm_list_parsing: bool = False) -> str:
-
+ enable_llm_list_parsing: bool = False,
+ ecosystem: str | None = None) -> str:
+
from vuln_analysis.utils.prompting import build_tool_descriptions
-
+
if not prompt:
- prompt = DEFAULT_CHECKLIST_PROMPT
+ if ecosystem:
+ # Build ecosystem-aware prompt with relevant few-shot examples
+ _escaped = MOD_FEW_SHOT.replace('{tool_descriptions}', '{{tool_descriptions}}')
+ prompt = _escaped.format(examples=get_mod_examples(ecosystem=ecosystem))
+ else:
+ prompt = DEFAULT_CHECKLIST_PROMPT
# Build tool descriptions with checklist-specific formatting
if tool_names:
@@ -145,12 +151,29 @@ async def generate_checklist(prompt: str | None,
"\n"
"\n\n"
"\n- If CVE describes a vulnerable function/method, first checklist item MUST "
- "check if code calls it"
- "\n- Vulnerable package version is already confirmed installed; focus on other "
- "exploitability factors"
+ "check if the EXACT vulnerable method (e.g., ClassName.methodName) is called "
+ "or reachable — not just the containing class or module"
+ "\n- If the CVE does NOT name a specific vulnerable function but describes a "
+ "vulnerability in a library's core functionality (e.g., deserialization, parsing, "
+ "encoding), the first checklist item MUST verify whether the library's primary "
+ "entry-point methods (e.g., ObjectMapper.readValue for jackson-databind, "
+ "XStream.fromXML for XStream) are used and reachable from application code "
+ "using the Call Chain Analyzer tool"
+ "\n- If CVE advisory data includes affected version ranges or fixed/patched versions "
+ "(e.g., GHSA vulnerable_version_range, NVD version configurations, RHSA package state), "
+ "include a checklist item to verify whether the installed version falls within the "
+ "vulnerable range using the available analysis tools"
+ "\n- Do not assume the installed version is vulnerable; verify it against advisory data"
"\n- Each item must be answerable with available analysis tools (code/doc search, "
"dependency checks)"
- "\n- Use specific technical names from CVE details (functions, components, configurations)"
+ "\n- Use EXACT technical names from CVE details (functions, components, configurations). "
+ "Do not confuse similar-sounding concepts — e.g., distinguish between different "
+ "properties, classes, or fix mechanisms mentioned in the advisory"
+ "\n- All questions must reference the specific vulnerable package from the CVE. "
+ "Do not investigate components from unrelated packages or libraries"
+ "\n- Focus on CODE REACHABILITY (function calls, imports, execution paths). "
+ "Configuration-level checks are secondary to verifying whether the vulnerable "
+ "code is actually called"
"\n- Maximum 5 checklist items; prioritize most critical exploitability checks"
"\n"
"\n\nGenerate checklist:"
diff --git a/src/vuln_analysis/utils/function_name_locator.py b/src/vuln_analysis/utils/function_name_locator.py
index 13f3adbc..40647965 100644
--- a/src/vuln_analysis/utils/function_name_locator.py
+++ b/src/vuln_analysis/utils/function_name_locator.py
@@ -195,20 +195,28 @@ async def locate_functions(self, query: str) -> list[str]:
return [
(
f"INFO: Package '{package}' is a standard library package. "
- f"make call with the Transitive code search tool"
+ f"Proceed by calling the Call Chain Analyzer tool with input: "
+ f"'{package},{function}'. "
+ f"Input format: 'package_name,function_name' "
+ f"or for Java: 'package_name,ClassName.methodName'."
)
]
else:
- is_standard_lib_api = await quick_standard_lib_check(package, self.coc_retriever.ecosystem)
- if is_standard_lib_api:
- self.is_std_package = True
- self.is_package_valid = True
- self.stdlib_cache.add_to_cache(package, self.coc_retriever.ecosystem)
- return [
- (
- f"INFO: Package '{package}' is a standard library package. "
- f"make call with the Transitive code search tool"
- )]
+ if (self.coc_retriever.ecosystem and
+ self.coc_retriever.ecosystem.value != Ecosystem.JAVA.value):
+ is_standard_lib_api = await quick_standard_lib_check(package, self.coc_retriever.ecosystem)
+ if is_standard_lib_api:
+ self.is_std_package = True
+ self.is_package_valid = True
+ self.stdlib_cache.add_to_cache(package, self.coc_retriever.ecosystem)
+ return [
+ (
+ f"INFO: Package '{package}' is a standard library package. "
+ f"Proceed by calling the Call Chain Analyzer tool with input: "
+ f"'{package},{function}'. "
+ f"Input format: 'package_name,function_name' "
+ f"or for Java: 'package_name,ClassName.methodName'."
+ )]
close_package_matches = self.handle_package_not_in_supported_packages(package)
if close_package_matches:
error_msg = (
diff --git a/src/vuln_analysis/utils/justification_parser.py b/src/vuln_analysis/utils/justification_parser.py
index ac94ed19..1ad8feef 100644
--- a/src/vuln_analysis/utils/justification_parser.py
+++ b/src/vuln_analysis/utils/justification_parser.py
@@ -38,11 +38,18 @@ class JustificationParser:
1. false_positive - CVE-to-package association is incorrect (wrong package or mismatched CVE)
-2. code_not_present - Vulnerable code/library is absent from the container
- (If code is not present, subsequent factors are irrelevant)
+2. code_not_present - Vulnerable package/library is not installed in the container,
+ OR the installed version is at or above the known fixed/patched version
+ from advisory data (the vulnerable code no longer exists in this version).
+ This includes cases where the package is entirely absent from the dependency tree.
+ If the CVE affects multiple packages and at least one IS present at a vulnerable
+ version, do NOT use this category — evaluate based on the present vulnerable package.
+ (If vulnerable code is not present, subsequent factors are irrelevant)
3. code_not_reachable - Vulnerable code exists but is never executed at runtime
- (Only applicable if code IS present but execution path analysis shows no calls)
+ (Only applicable if Call Chain Analyzer explicitly shows function is NOT
+ reachable. NEVER use this category when Call Chain Analyzer confirms the
+ function IS reachable from application code)
4. requires_configuration - Exploitation requires specific configuration that is disabled
(Configuration prevents exploitation)
@@ -70,6 +77,18 @@ class JustificationParser:
- Vulnerable code is REACHABLE from attack surfaces (user input, network, file processing)
- No effective mitigations or protections are in place
+REACHABILITY PRIORITY:
+When investigation evidence confirms the vulnerable package's functions are
+USED or REACHABLE from application code (via Call Chain Analyzer, Function
+Locator, or Code Keyword Search showing direct usage), classify as
+"vulnerable" unless the installed version is definitively NOT in the
+vulnerable range. Inability to confirm specific exploitation conditions
+(untrusted input, specific data patterns like deep nesting or recursive
+collections, etc.) does NOT warrant a non-vulnerable classification — static
+analysis cannot reliably determine these runtime conditions, and they should
+be assumed present when the vulnerable library's functionality is actively
+used in the codebase.
+
IF EXPLOITATION CONDITIONS ARE NOT MET:
Select the PRIMARY reason for non-exploitability following the logical precedence
order above. For example:
diff --git a/src/vuln_analysis/utils/prompting.py b/src/vuln_analysis/utils/prompting.py
index 45837e0c..f1276b3c 100644
--- a/src/vuln_analysis/utils/prompting.py
+++ b/src/vuln_analysis/utils/prompting.py
@@ -53,15 +53,19 @@ def build_tool_descriptions(tool_names: list[str]) -> list[str]:
f"{ToolNames.CODE_KEYWORD_SEARCH}: Exact text matching for function names, class names, or imports"
)
+ call_chain_format_note = (
+ "Input MUST be exactly ONE query per call: 'package,function' or 'maven_gav,Class.method' or 'maven_gav,fqcn.method' (java, fqcn preferred). "
+ "NEVER combine multiple queries with 'and' in a single input"
+ )
if ToolNames.FUNCTION_CALLER_FINDER in tool_names and ToolNames.CALL_CHAIN_ANALYZER in tool_names:
descriptions.append(
- f"{ToolNames.CALL_CHAIN_ANALYZER}: Checks if functions are reachable from application code\n"
+ f"{ToolNames.CALL_CHAIN_ANALYZER}: Checks if functions are reachable from application code. {call_chain_format_note}\n"
f"{ToolNames.FUNCTION_CALLER_FINDER}: Finds which functions call specific library functions\n"
f"Use '{ToolNames.FUNCTION_CALLER_FINDER}' + '{ToolNames.CALL_CHAIN_ANALYZER}' together to trace function reachability"
)
elif ToolNames.CALL_CHAIN_ANALYZER in tool_names:
descriptions.append(
- f"{ToolNames.CALL_CHAIN_ANALYZER}: Checks if functions are reachable from application code"
+ f"{ToolNames.CALL_CHAIN_ANALYZER}: Checks if functions are reachable from application code. {call_chain_format_note}"
)
elif ToolNames.FUNCTION_CALLER_FINDER in tool_names:
descriptions.append(
@@ -70,9 +74,17 @@ def build_tool_descriptions(tool_names: list[str]) -> list[str]:
if ToolNames.CVE_WEB_SEARCH in tool_names:
descriptions.append(
- f"{ToolNames.CVE_WEB_SEARCH}: External vulnerability information lookup"
+ f"{ToolNames.CVE_WEB_SEARCH}: External vulnerability information lookup. "
+ f"IMPORTANT: Search results may contain information for MULTIPLE library versions. "
+ f"Always use {ToolNames.FUNCTION_LIBRARY_VERSION_FINDER} first to determine the exact installed version, "
+ f"then cross-reference web search results against that version"
)
-
+
+ if ToolNames.FUNCTION_LIBRARY_VERSION_FINDER in tool_names:
+ descriptions.append(
+            f"{ToolNames.FUNCTION_LIBRARY_VERSION_FINDER}: Finds the exact installed version of a library/package in the application's dependency tree"
+ )
+
if ToolNames.CONTAINER_ANALYSIS_DATA in tool_names:
descriptions.append(
f"{ToolNames.CONTAINER_ANALYSIS_DATA}: Retrieves findings from earlier container scan analysis"
@@ -101,6 +113,36 @@ def build_tool_descriptions(tool_names: list[str]) -> list[str]:
- Connect findings to exploitability conditions
3. FOCUS: Use only definitive checklist results; ignore inconclusive items
+
+4. VERSION CHECK PRIORITY: If a checklist answer establishes that the installed
+ version of the vulnerable package is NOT within the vulnerable range (i.e.,
+ at or above the fixed/patched version), the CVE is NOT exploitable for that
+ package. However, when the CVE affects multiple packages, evaluate each
+ package independently — a non-vulnerable version of one package does NOT
+ override a vulnerable version of a different package.
+
+5. CONTRADICTIONS: If checklist answers contradict each other (e.g., one says a
+ function is not used while another says it is reachable), prefer the answer
+ backed by specific tool evidence (e.g., Call Chain Analyzer showing reachability)
+ over vague or negative conclusions
+
+6. TOOL EVIDENCE PRIORITY: When Call Chain Analyzer results consistently show
+ a function is NOT REACHABLE, conclude the CVE is not exploitable for that
+ code path. Do not override Call Chain Analyzer results with inferences from
+ Code Keyword Search alone — finding a class name or keyword in source code
+ does not prove the specific vulnerable method is called or reachable
+
+7. REACHABILITY IMPLIES EXPLOITABILITY: When investigation evidence confirms
+ that the vulnerable package's functions are USED or REACHABLE from
+ application code (via Call Chain Analyzer, Function Locator, or Code
+ Keyword Search showing direct usage of the library's entry-point methods),
+ conclude the CVE IS exploitable unless the installed version is definitively
+ NOT in the vulnerable range. Inability to confirm specific exploitation
+ conditions (e.g., whether untrusted input reaches the function, whether
+ specific data patterns like recursive collections or deep nesting are used)
+ does NOT override positive usage/reachability evidence — these conditions
+ are difficult to determine through static analysis and should be assumed
+ present when the vulnerable library's functionality is actively used
@@ -128,37 +170,37 @@ def build_tool_descriptions(tool_names: list[str]) -> list[str]:
"specialized search and analysis tools."
)
-AGENT_PROMPT_TEMPLATE = """
-Answer the investigation question using the available tools. If the input is not a question, formulate it into a question first. A Tool Selection Strategy is provided to help you decide which tools to use. Focus on answering the question. Include your intermediate reasoning in the final answer.
-
+AGENT_PROMPT_TEMPLATE = """## Task
+Answer the investigation question using the available tools. Include your intermediate reasoning in the final answer.
-
+## Available Tools
{tools}
-
-
+## Tool Selection Strategy
{tool_selection_strategy}
-
-
-Follow this format exactly (start each line with one of the specified prefixes):
+## Rules
+You have a LIMITED number of tool calls. Follow these rules strictly:
+
+1. ONE query per tool call. Never combine with 'and' or 'or'.
+2. For Java tools (Call Chain Analyzer, Function Locator), use the full Maven GAV: 'groupId:artifactId:version,fqcn.methodName' (e.g., 'commons-beanutils:commons-beanutils:1.9.4,org.apache.commons.beanutils.PropertyUtilsBean.getProperty'). For Go: 'module/path,FunctionName'. For Python/JS/C: 'package_name,function_name'. Always search for the EXACT vulnerable method, not just the containing class or module.
+3. Plan first. Identify the 2-3 most important checks before making tool calls. Do not repeat similar searches.
+4. VERSION CHECK: Use Function Library Version Finder to get the installed version, then compare it against advisory data (GHSA vulnerable_version_range, NVD configurations, etc.). If the installed version is at or above the fixed/patched version, the vulnerability is FIXED — conclude immediately.
+5. EVIDENCE-BASED ANSWERS ONLY: Only claim the application does something if code analysis tools (Function Locator, Call Chain Analyzer, Code Keyword Search) confirm it. CVE Web Search describes general vulnerability info, NOT this application. If a search returns no results, say "no evidence found" — do NOT generalize to "not used in the codebase."
+6. Conclude promptly. When you have enough evidence, go directly to Final Answer. Do NOT output "Action: None" — either use a real tool or provide your Final Answer.
+## Response Format
Question: the input question you must answer
Thought: think about what action to take next
Action: the tool to use, must be one of [{tool_names}]
Action Input: the specific input for that tool
Observation: the result returned by the tool
-... (repeat Thought/Action/Action Input/Observation cycle as many times as needed)
+... (repeat Thought/Action/Action Input/Observation as needed)
Thought: I now know the final answer
-Final Answer: provide the answer with supporting evidence from your investigation
-
-
-
-
+Final Answer: provide the answer with supporting evidence
-
+## Question
{input}
-
Begin your investigation:
Thought:{agent_scratchpad}"""
@@ -398,9 +440,8 @@ def get_agent_prompt(sys_prompt: str | None = None,
# Select template with or without examples
if prompt_examples:
prompt_template = AGENT_PROMPT_TEMPLATE.replace(
- "",
- "\n" + AGENT_EXAMPLES_FOR_PROMPT
- # AGENT_EXAMPLES_FOR_PROMPT_2
+ "## Question",
+ "## Examples\n" + AGENT_EXAMPLES_FOR_PROMPT + "\n## Question"
)
else:
prompt_template = AGENT_PROMPT_TEMPLATE
@@ -506,6 +547,9 @@ def build_prompt(self) -> str:
'Example {idx}: CVE Details:\n- CVE ID: CVE-2023-5363\n- CVE description: Issue summary: A bug has been identified in the processing of key and initialisation vector (IV) lengths. This can lead to potential truncation or overruns during the initialisation of some symmetric ciphers. Impact summary: A truncation in the IV can result in non-uniqueness, which could result in loss of confidentiality for some cipher modes. When calling EVP_EncryptInit_ex2(), EVP_DecryptInit_ex2() or EVP_CipherInit_ex2() the provided OSSL_PARAM array is processed after the key and IV have been established. Any alterations to the key length, via the "keylen" parameter or the IV length, via the "ivlen" parameter, within the OSSL_PARAM array will not take effect as intended, potentially causing truncation or overreading of these values. The following ciphers and cipher modes are impacted: RC2, RC4, RC5, CCM, GCM and OCB. For the CCM, GCM and OCB cipher modes, truncation of the IV can result in loss of confidentiality. For example, when following NIST\'s SP 800-38D section 8.2.1 guidance for constructing a deterministic IV for AES in GCM mode, truncation of the counter portion could lead to IV reuse. Both truncations and overruns of the key and overruns of the IV will produce incorrect results and could, in some cases, trigger a memory exception. However, these issues are not currently assessed as security critical.\nChanging the key and/or IV lengths is not considered to be a common operation and the vulnerable API was recently introduced. Furthermore it is likely that application developers will have spotted this problem during testing since decryption would fail unless both peers in the communication were similarly vulnerable. For these reasons we expect the probability of an application being vulnerable to this to be quite low. However if an application is vulnerable then this issue is considered very serious. 
For these reasons we have assessed this issue as Moderate severity overall.\nThe OpenSSL SSL/TLS implementation is not affected by this issue. The OpenSSL 3.0 and 3.1 FIPS providers are not affected by this because the issue lies outside of the FIPS provider boundary. OpenSSL 3.1 and 3.0 are vulnerable to this issue.\n- CVSS Vector: CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:N/A:N\n- Notable Vulnerable Software Vendors: [\'Debian\', \'Netapp\', \'Openssl\']\n\nExample {idx}: Checklist:\n[\n\t"Check OpenSSL Version: What version of OpenSSL is running in the container image? The vulnerability specifically affects OpenSSL versions 3.0 and 3.1. Is the version running in the container within the vulnerability range? If the container is running these versions, it may be vulnerable.",\n\t"Identify Affected Cipher Modes: Does the application within the container image use any of the affected cipher modes: RC2, RC4, RC5, CCM, GCM, or OCB? Special attention should be given to applications using CCM, GCM, and OCB modes as these are particularly noted for potential loss of confidentiality due to IV truncation.",\n\t"Review Cryptographic Operations: Does the code or configuration of applications using OpenSSL have any instances where `EVP_EncryptInit_ex2()`, `EVP_DecryptInit_ex2()`, or `EVP_CipherInit_ex2()` are called? Are there any modifications to the `keylen` or `ivlen` parameters after initialization which might not be taking effect as intended?",\n\t"Check for Custom Cryptographic Implementations: Since changing the key and/or IV lengths is not a common operation and the issue is in a recently introduced API, it\'s crucial to identify if any custom cryptographic implementations might be performing such operations. This is less likely but should be checked especially in bespoke or highly customized applications. Are there any custom cryptographic implemenations changing the key and/or IV lengths?"\n]',
'Example {idx}: CVE Details:\n- CVE ID: CVE-2024-2961\n- CVE description: The iconv() function in the GNU C Library versions 2.39 and older may overflow the output buffer passed to it by up to 4 bytes when converting strings to the ISO-2022-CN-EXT character set, which may be used to crash an application or overwrite a neighbouring variable.\n- CWE Name: CWE-787: Out-of-bounds Write (4.14)\n- CWE Description: The product writes data past the end, or before the beginning, of the intended buffer.\nTypically, this can result in corruption of data, a crash, or code execution. The product may modify an index or perform pointer arithmetic that references a memory location that is outside of the boundaries of the buffer. A subsequent write operation then produces undefined or unexpected results.\n- Notable Vulnerable Software Vendors: [\'GNU\']\n\nExample {idx}: Checklist:\n[\n\t"Identify Usage of `iconv()` Function: Review the application code or dependencies. Is the `iconv()` function used? Look particularly for conversions involving the ISO-2022-CN-EXT character set. This function is the specific target of the vulnerability.",\n\t"Assess Data Handling and Boundary Conditions: Since the vulnerability involves an out-of-bounds write, it\'s crucial to analyze how data boundaries are handled in the code. Are there any custom implementations or patches that might mitigate boundary issues around buffer sizes?",\n\t"Review Application\'s Character Encoding Needs: Does the application specifically need to handle the ISO-2022-CN-EXT character set? If not, consider disabling this character set or using alternative safe functions or libraries for character set conversions.",\n\t"Evaluate Network Exposure and Attack Surface: Are the affected services exposed to the network? If so, this could increase the risk of exploitation. Additionally, if the application using the `iconv()` function is accessible externally, the risk is higher."\n]',
"Example {idx}: CVE Details:\n- CVE ID: GHSA-8ghj-p4vj-mr35\n- CVE description: An issue was discovered in Pillow before 10.0.0. It is a Denial of Service that uncontrollably allocates memory to process a given task, potentially causing a service to crash by having it run out of memory. This occurs for truetype in ImageFont when textlength in an ImageDraw instance operates on a long text argument.\n- CVSS Vector: CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:N/I:N/A:H\n- CWE Name: CWE-770: Allocation of Resources Without Limits or Throttling (4.14)\n- CWE Description: The product allocates a reusable resource or group of resources on behalf of an actor without imposing any restrictions on the size or number of resources that can be allocated, in violation of the intended security policy for that actor.\n- Code frequently has to work with limited resources, so programmers must be careful to ensure that resources are not consumed too quickly, or too easily. Without use of quotas, resource limits, or other protection mechanisms, it can be easy for an attacker to consume many resources by rapidly making many requests, or causing larger resources to be used than is needed. When too many resources are allocated, or if a single resource is too large, then it can prevent the code from working correctly, possibly leading to a denial of service.\n- Notable Vulnerable Software Vendors: ['Fedoraproject', 'Python']\n- GHSA Summary: Pillow Denial of Service vulnerability\n- GHSA Details: [<'first_patched_version': '10.0.0', 'package': <'ecosystem': 'pip', 'name': 'pillow'>, 'vulnerable_functions': ['PIL.ImageFont'], 'vulnerable_version_range': '>= 0, < 10.0.0'>]\n- GHSA CVSS Vector: CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:N/I:N/A:H\n\nExample {idx}: Checklist:\n[\n\t\"Assess Usage of Vulnerable Functions: Specifically, the vulnerability is related to the `PIL.ImageFont` module when processing long text arguments. Does the application code or dependencies use this module and functionality? 
If your applications use this module to process user-supplied or uncontrolled text inputs, they are likely at risk.\",\n\t\"Evaluate Resource Limits: The vulnerability leads to a denial of service through memory exhaustion. Are there any resource limits set at the container level (e.g., using Docker or Kubernetes settings) that might mitigate the impact of such an attack? Consider setting or reviewing memory limits to prevent a single container from consuming all available system resources.\"\n]",
+ 'Example {idx}: CVE Details:\n- CVE ID: CVE-2019-10086\n- CVE description: In Apache Commons BeanUtils 1.9.2 and earlier, the default property descriptors allow access to class loader via the class property, enabling remote code execution.\n- CVSS Vector: CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:L/I:L/A:L\n- CWE Name: CWE-502: Deserialization of Untrusted Data (4.14)\n- CWE Description: The product deserializes untrusted data without sufficiently verifying that the resulting data will be valid.\n- Known Affected Software: commons-beanutils:commons-beanutils versions before 1.9.4\n- Notable Vulnerable Software Vendors: [\'Apache\']\n\nExample {idx}: Checklist:\n[\n\t"Is the specific vulnerable method `org.apache.commons.beanutils.PropertyUtilsBean.getProperty` from the `commons-beanutils:commons-beanutils` package reachable from the application code? This is the exact method that enables class loader access via property descriptors.",\n\t"Is the installed version of `commons-beanutils:commons-beanutils` within the vulnerable range (before 1.9.4)? Use dependency analysis tools to verify the exact version.",\n\t"Does the application pass user-controlled or untrusted input to BeanUtils property access methods? If only internal configuration data is used, the attack vector is not present."\n]',
+ 'Example {idx}: CVE Details:\n- CVE ID: CVE-2022-41723\n- CVE description: A maliciously crafted HTTP/2 stream could cause excessive CPU consumption in the HPACK decoder in the golang.org/x/net/http2 package, leading to denial of service.\n- CVSS Vector: CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:N/I:N/A:H\n- CWE Name: CWE-400: Uncontrolled Resource Consumption (4.14)\n- CWE Description: The product does not properly control the allocation and maintenance of a limited resource, thereby enabling an actor to influence the amount of resources consumed, eventually leading to the exhaustion of available resources.\n- Notable Vulnerable Software Vendors: [\'Golang\']\n\nExample {idx}: Checklist:\n[\n\t"Is the `golang.org/x/net/http2` package imported or used in the application code? Specifically, is the `hpack.Decoder.DecodeFull` method reachable from the application\'s HTTP server or client?",\n\t"Is the installed version of `golang.org/x/net` within the vulnerable range (before 0.7.0)? Check `go.mod` or `go.sum` for the exact version.",\n\t"Does the application expose HTTP/2 endpoints that accept connections from untrusted clients? If the application only uses HTTP/1.1, the attack vector is not present."\n]',
+ 'Example {idx}: CVE Details:\n- CVE ID: CVE-2021-3807\n- CVE description: ansi-regex before 5.0.1 and 6.x before 6.0.1 is vulnerable to Regular Expression Denial of Service (ReDoS) due to an inefficient regular expression when processing crafted ANSI escape codes.\n- CVSS Vector: CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:N/I:N/A:H\n- CWE Name: CWE-1333: Inefficient Regular Expression Complexity (4.14)\n- CWE Description: The product uses a regular expression with an inefficient, possibly exponential worst-case computational complexity that consumes excessive CPU cycles.\n- Notable Vulnerable Software Vendors: [\'chalk\']\n\nExample {idx}: Checklist:\n[\n\t"Is the `ansi-regex` npm package imported or required in the application code? Specifically, is the `ansiRegex()` function called with user-controlled input?",\n\t"Is the installed version of `ansi-regex` within the vulnerable range (before 5.0.1 or 6.x before 6.0.1)? Check `package.json` or `node_modules/ansi-regex/package.json` for the exact version.",\n\t"Does the application process ANSI escape codes from untrusted sources (e.g., user input in terminal output, log processing)? If the application only processes its own internal strings, the ReDoS attack vector is not present."\n]',
]
ex_statements = [
@@ -515,6 +559,9 @@ def build_prompt(self) -> str:
'Example {idx}: CVE Details:\n- CVE ID: CVE-2023-5363\n- CVE description: Issue summary: A bug has been identified in the processing of key and initialisation vector (IV) lengths. This can lead to potential truncation or overruns during the initialisation of some symmetric ciphers. Impact summary: A truncation in the IV can result in non-uniqueness, which could result in loss of confidentiality for some cipher modes. When calling EVP_EncryptInit_ex2(), EVP_DecryptInit_ex2() or EVP_CipherInit_ex2() the provided OSSL_PARAM array is processed after the key and IV have been established. Any alterations to the key length, via the "keylen" parameter or the IV length, via the "ivlen" parameter, within the OSSL_PARAM array will not take effect as intended, potentially causing truncation or overreading of these values. The following ciphers and cipher modes are impacted: RC2, RC4, RC5, CCM, GCM and OCB. For the CCM, GCM and OCB cipher modes, truncation of the IV can result in loss of confidentiality. For example, when following NIST\'s SP 800-38D section 8.2.1 guidance for constructing a deterministic IV for AES in GCM mode, truncation of the counter portion could lead to IV reuse. Both truncations and overruns of the key and overruns of the IV will produce incorrect results and could, in some cases, trigger a memory exception. However, these issues are not currently assessed as security critical.\nChanging the key and/or IV lengths is not considered to be a common operation and the vulnerable API was recently introduced. Furthermore it is likely that application developers will have spotted this problem during testing since decryption would fail unless both peers in the communication were similarly vulnerable. For these reasons we expect the probability of an application being vulnerable to this to be quite low. However if an application is vulnerable then this issue is considered very serious. 
For these reasons we have assessed this issue as Moderate severity overall.\nThe OpenSSL SSL/TLS implementation is not affected by this issue. The OpenSSL 3.0 and 3.1 FIPS providers are not affected by this because the issue lies outside of the FIPS provider boundary. OpenSSL 3.1 and 3.0 are vulnerable to this issue.\n- CVSS Vector: CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:N/A:N\n- Notable Vulnerable Software Vendors: [\'Debian\', \'Netapp\', \'Openssl\']\n\nExample {idx}: Checklist:\n[\n\t"Identify Affected Cipher Modes: Determine if the application within the container image uses any of the affected cipher modes: RC2, RC4, RC5, CCM, GCM, or OCB. Special attention should be given to applications using CCM, GCM, and OCB modes as these are particularly noted for potential loss of confidentiality due to IV truncation.",\n\t"Review Cryptographic Operations: Examine the code or configuration of applications using OpenSSL for any instances where `EVP_EncryptInit_ex2()`, `EVP_DecryptInit_ex2()`, or `EVP_CipherInit_ex2()` are called. Check if there are any modifications to the `keylen` or `ivlen` parameters after initialization which might not be taking effect as intended.",\n\t"Check for Custom Cryptographic Implementations: Since changing the key and/or IV lengths is not a common operation and the issue is in a recently introduced API, it\'s crucial to identify if any custom cryptographic implementations might be performing such operations. This is less likely but should be checked especially in bespoke or highly customized applications."\n]',
'Example {idx}: CVE Details:\n- CVE ID: CVE-2024-2961\n- CVE description: The iconv() function in the GNU C Library versions 2.39 and older may overflow the output buffer passed to it by up to 4 bytes when converting strings to the ISO-2022-CN-EXT character set, which may be used to crash an application or overwrite a neighbouring variable.\n- CWE Name: CWE-787: Out-of-bounds Write (4.14)\n- CWE Description: The product writes data past the end, or before the beginning, of the intended buffer.\nTypically, this can result in corruption of data, a crash, or code execution. The product may modify an index or perform pointer arithmetic that references a memory location that is outside of the boundaries of the buffer. A subsequent write operation then produces undefined or unexpected results.\n- Notable Vulnerable Software Vendors: [\'GNU\']\n\nExample {idx}: Checklist:\n[\n\t"Identify Usage of `iconv()` Function: Review the application code or dependencies to check if the `iconv()` function is used, particularly for conversions involving the ISO-2022-CN-EXT character set. This function is the specific target of the vulnerability.",\n\t"Assess Data Handling and Boundary Conditions: Since the vulnerability involves an out-of-bounds write, it\'s crucial to analyze how data boundaries are handled in the code. Look for any custom implementations or patches that might mitigate boundary issues around buffer sizes.",\n\t"Review Application\'s Character Encoding Needs: Determine if the application specifically needs to handle the ISO-2022-CN-EXT character set. If not, consider disabling this character set or using alternative safe functions or libraries for character set conversions.",\n\t"Evaluate Network Exposure and Attack Surface: Consider whether the affected services are exposed to the network, which could increase the risk of exploitation. If the application using the `iconv()` function is accessible externally, the risk is higher."\n]',
"Example {idx}: CVE Details:\n- CVE ID: GHSA-8ghj-p4vj-mr35\n- CVE description: An issue was discovered in Pillow before 10.0.0. It is a Denial of Service that uncontrollably allocates memory to process a given task, potentially causing a service to crash by having it run out of memory. This occurs for truetype in ImageFont when textlength in an ImageDraw instance operates on a long text argument.\n- CVSS Vector: CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:N/I:N/A:H\n- CWE Name: CWE-770: Allocation of Resources Without Limits or Throttling (4.14)\n- CWE Description: The product allocates a reusable resource or group of resources on behalf of an actor without imposing any restrictions on the size or number of resources that can be allocated, in violation of the intended security policy for that actor.\n- Code frequently has to work with limited resources, so programmers must be careful to ensure that resources are not consumed too quickly, or too easily. Without use of quotas, resource limits, or other protection mechanisms, it can be easy for an attacker to consume many resources by rapidly making many requests, or causing larger resources to be used than is needed. When too many resources are allocated, or if a single resource is too large, then it can prevent the code from working correctly, possibly leading to a denial of service.\n- Notable Vulnerable Software Vendors: ['Fedoraproject', 'Python']\n- GHSA Summary: Pillow Denial of Service vulnerability\n- GHSA Details: [<'first_patched_version': '10.0.0', 'package': <'ecosystem': 'pip', 'name': 'pillow'>, 'vulnerable_functions': ['PIL.ImageFont'], 'vulnerable_version_range': '>= 0, < 10.0.0'>]\n- GHSA CVSS Vector: CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:N/I:N/A:H\n\nExample {idx}: Checklist:\n[\n\t\"Assess Usage of Vulnerable Functions: Specifically, the vulnerability is related to the `PIL.ImageFont` module when processing long text arguments. 
Review the application code or dependencies to see if this module and functionality are used. If your applications use this module to process user-supplied or uncontrolled text inputs, they are likely at risk.\",\n\t\"Evaluate Resource Limits: Since the vulnerability leads to a denial of service through memory exhaustion, check if there are any resource limits set at the container level (e.g., using Docker or Kubernetes settings) that might mitigate the impact of such an attack. Consider setting or reviewing memory limits to prevent a single container from consuming all available system resources.\"\n]",
+ 'Example {idx}: CVE Details:\n- CVE ID: CVE-2019-10086\n- CVE description: In Apache Commons BeanUtils 1.9.3 and earlier, the default property descriptors allow access to class loader via the class property, enabling remote code execution.\n- CVSS Vector: CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:L/I:L/A:L\n- CWE Name: CWE-502: Deserialization of Untrusted Data (4.14)\n- CWE Description: The product deserializes untrusted data without sufficiently verifying that the resulting data will be valid.\n- Known Affected Software: commons-beanutils:commons-beanutils versions before 1.9.4\n- Notable Vulnerable Software Vendors: [\'Apache\']\n\nExample {idx}: Checklist:\n[\n\t"Is the specific vulnerable method `org.apache.commons.beanutils.PropertyUtilsBean.getProperty` from the `commons-beanutils:commons-beanutils` package reachable from the application code? This is the exact method that enables class loader access via property descriptors.",\n\t"Is the installed version of `commons-beanutils:commons-beanutils` within the vulnerable range (before 1.9.4)? Use dependency analysis tools to verify the exact version.",\n\t"Does the application pass user-controlled or untrusted input to BeanUtils property access methods? If only internal configuration data is used, the attack vector is not present."\n]',
+ 'Example {idx}: CVE Details:\n- CVE ID: CVE-2022-41723\n- CVE description: A maliciously crafted HTTP/2 stream could cause excessive CPU consumption in the HPACK decoder in the golang.org/x/net/http2 package, leading to denial of service.\n- CVSS Vector: CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:N/I:N/A:H\n- CWE Name: CWE-400: Uncontrolled Resource Consumption (4.14)\n- CWE Description: The product does not properly control the allocation and maintenance of a limited resource, thereby enabling an actor to influence the amount of resources consumed, eventually leading to the exhaustion of available resources.\n- Notable Vulnerable Software Vendors: [\'Golang\']\n\nExample {idx}: Checklist:\n[\n\t"Is the `golang.org/x/net/http2` package imported or used in the application code? Specifically, is the `hpack.Decoder.DecodeFull` method reachable from the application\'s HTTP server or client?",\n\t"Is the installed version of `golang.org/x/net` within the vulnerable range (before 0.7.0)? Check `go.mod` or `go.sum` for the exact version.",\n\t"Does the application expose HTTP/2 endpoints that accept connections from untrusted clients? If the application only uses HTTP/1.1, the attack vector is not present."\n]',
+ 'Example {idx}: CVE Details:\n- CVE ID: CVE-2021-3807\n- CVE description: ansi-regex before 5.0.1 and 6.x before 6.0.1 is vulnerable to Regular Expression Denial of Service (ReDoS) due to an inefficient regular expression when processing crafted ANSI escape codes.\n- CVSS Vector: CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:N/I:N/A:H\n- CWE Name: CWE-1333: Inefficient Regular Expression Complexity (4.14)\n- CWE Description: The product uses a regular expression with an inefficient, possibly exponential worst-case computational complexity that consumes excessive CPU cycles.\n- Notable Vulnerable Software Vendors: [\'chalk\']\n\nExample {idx}: Checklist:\n[\n\t"Is the `ansi-regex` npm package imported or required in the application code? Specifically, is the `ansiRegex()` function called with user-controlled input?",\n\t"Is the installed version of `ansi-regex` within the vulnerable range (before 5.0.1 or 6.x before 6.0.1)? Check `package.json` or `node_modules/ansi-regex/package.json` for the exact version.",\n\t"Does the application process ANSI escape codes from untrusted sources (e.g., user input in terminal output, log processing)? If the application only processes its own internal strings, the ReDoS attack vector is not present."\n]',
]
FEW_SHOT = """Generate a checklist for a security analyst to use when assessing the exploitability of a specific CVE within a containerized environment. Use the provided examples as a guide to understand how to construct a checklist from a given set of CVE details, then apply this understanding to create a specific checklist for the CVE details provided below. All output should be a comma separated list enclosed in square brackets with each list item enclosed in quotes.
@@ -546,10 +593,10 @@ def build_prompt(self) -> str:
Example 2: Checklist:
[
- “Vulnerable package check. Does the project use the lxml library, which is the affected package? If lxml is not a dependency in your project, then your code is not vulnerable to this CVE.”,
- “Vulnerable version check. Is the version of lxml that the project depends on vulnerable? According to the vulnerability details, versions 4.9.0 and earlier are vulnerable.”,
- “Vulnerable version check of connected dependency. Is the version of libxml, the connected dependency, that the project depends on vulnerable? The package is only vulnerable if libxml 2.9.10 through 2.9.14 is also present.”,
- “Review code for vulnerable functionality. The library is vulnerable through its `iterwalk` function, which is also utilized by the `canonicalize` function. Are either of these functions used in your code base?”
+ "Vulnerable package check. Does the project use the lxml library, which is the affected package? If lxml is not a dependency in your project, then your code is not vulnerable to this CVE.",
+ "Vulnerable version check. Is the version of lxml that the project depends on vulnerable? According to the vulnerability details, versions 4.9.0 and earlier are vulnerable.",
+ "Vulnerable version check of connected dependency. Is the version of libxml, the connected dependency, that the project depends on vulnerable? The package is only vulnerable if libxml 2.9.10 through 2.9.14 is also present.",
+ "Review code for vulnerable functionality. The library is vulnerable through its `iterwalk` function, which is also utilized by the `canonicalize` function. Are either of these functions used in your code base?"
]
Given CVE Details:
@@ -574,11 +621,18 @@ def build_prompt(self) -> str:
checklist item must verify whether that function in that package or library is called or imported
in the codebase - function should be specified together with the package name,
for example : 'Is the function1 function from the package1 package called in the codebase?'
- - Focus on exploitability factors (version presence is already confirmed)
+ - When advisory data includes version ranges or fixed versions, verify
+ the installed version is within the vulnerable range
+ - Focus on exploitability factors: presence, usage, reachability, and mitigations
- Include specific technical names from the CVE (functions, libraries,
configurations, cipher modes, etc.)
- Consider the attack vector (network exposure, user input, file processing, etc.)
- Address relevant security controls or mitigations
+ - Every question must be directly relevant to the specific vulnerable component
+ and package mentioned in the CVE. Do NOT investigate components from unrelated
+ packages or ask generic security practice questions (e.g., generic input validation,
+ generic deserialization practices) unless the CVE's attack vector specifically
+ requires it
3. INVESTIGATION TOOLS AVAILABLE:
{tool_descriptions}
@@ -586,8 +640,11 @@ def build_prompt(self) -> str:
Design questions that can be answered using these analysis capabilities.
4. COMPLETENESS:
- - Cover the vulnerability chain: presence → usage → exploitability
+ - Cover the vulnerability chain: presence → version check → usage → exploitability
- Each item should independently contribute to understanding exploit risk
+ - Stay within scope: only investigate the specific package and components
+ described in the CVE advisory. Do not branch into unrelated libraries or
+ generic security assessments
@@ -639,11 +696,39 @@ def build_prompt(self) -> str:
""" + additional_intel_prompting
-def get_mod_examples(type='questions', choices=[0, 1]):
+# Mapping from ecosystem to relevant example indices
+# Python and C map to two same-ecosystem examples; Java, Go, and JavaScript map to one Python baseline (index 0) plus one ecosystem-specific example
+_ECOSYSTEM_EXAMPLE_CHOICES = {
+ 'python': [0, 1], # Two Python examples (urllib.parse, email.utils.parseaddr)
+ 'c': [3, 4], # Two C/C++ examples (OpenSSL, iconv)
+ 'java': [0, 6], # One Python baseline + Java (commons-beanutils)
+ 'go': [0, 7], # One Python baseline + Go (golang.org/x/net)
+ 'javascript': [0, 8], # One Python baseline + JavaScript (ansi-regex)
+}
+
+
+def get_mod_examples(type='questions', choices=None, ecosystem=None):
+ """
+ Get few-shot examples for the checklist prompt.
+
+ Args:
+ type: 'questions' or 'statements' - which example set to use
+ choices: Explicit list of example indices to include. Overrides ecosystem.
+ ecosystem: Ecosystem string (e.g., 'java', 'python', 'go', 'javascript', 'c').
+ Used to select relevant examples if choices is not provided.
+ """
if type == 'questions':
- ex_list = [q for idx, q in enumerate(ex_questions) if idx in choices]
+ ex_source = ex_questions
else:
- ex_list = [s for idx, s in enumerate(ex_statements) if idx in choices]
+ ex_source = ex_statements
+
+ if choices is None:
+ if ecosystem and ecosystem.lower() in _ECOSYSTEM_EXAMPLE_CHOICES:
+ choices = _ECOSYSTEM_EXAMPLE_CHOICES[ecosystem.lower()]
+ else:
+ choices = [0, 1] # Default: first two Python examples
+
+ ex_list = [q for idx, q in enumerate(ex_source) if idx in choices]
examples = '\n'.join(q.format(idx=idx + 1) for idx, q in enumerate(ex_list))
- return examples
+ return examples
\ No newline at end of file
diff --git a/tests/test_base_tool_descriptions.py b/tests/test_base_tool_descriptions.py
index 9a939aa5..a024cd27 100644
--- a/tests/test_base_tool_descriptions.py
+++ b/tests/test_base_tool_descriptions.py
@@ -58,7 +58,8 @@ def test_base_all_tools():
ToolNames.CALL_CHAIN_ANALYZER,
ToolNames.FUNCTION_CALLER_FINDER,
ToolNames.CVE_WEB_SEARCH,
- ToolNames.CONTAINER_ANALYSIS_DATA
+ ToolNames.CONTAINER_ANALYSIS_DATA,
+ ToolNames.FUNCTION_LIBRARY_VERSION_FINDER
]
result = build_tool_descriptions(tool_names)
@@ -75,7 +76,8 @@ def test_base_all_tools():
assert "Function Caller Finder" in all_text
assert "CVE Web Search" in all_text
assert "Container Analysis Data" in all_text
-
+ assert "Function Library Version Finder" in all_text
+
print("✓ Base function includes all tools")
@@ -141,15 +143,40 @@ def test_mod_few_shot_structure():
print("✓ MOD_FEW_SHOT structure validated")
+def test_cve_web_search_description_warns_about_versions():
+ """Test that CVE Web Search description includes version warning."""
+ tool_names = [ToolNames.CVE_WEB_SEARCH]
+ result = build_tool_descriptions(tool_names)
+ assert len(result) == 1
+ desc = result[0]
+ assert "MULTIPLE library versions" in desc
+ assert ToolNames.FUNCTION_LIBRARY_VERSION_FINDER in desc
+
+ print("✓ CVE Web Search description includes version warning")
+
+
+def test_agent_prompt_contains_version_awareness_instructions():
+ """Test that agent prompt template contains version awareness instructions."""
+ from vuln_analysis.utils.prompting import get_agent_prompt
+
+ prompt = get_agent_prompt()
+ assert "VERSION CHECK" in prompt
+ assert "Function Library Version Finder" in prompt
+
+ print("✓ Agent prompt contains version awareness instructions")
+
+
if __name__ == "__main__":
print("Running Base Tool Descriptions tests...\n")
-
+
test_base_returns_list()
test_base_descriptions_format()
test_base_all_tools()
test_base_empty_list()
test_checklist_formats_descriptions()
test_mod_few_shot_structure()
-
+ test_cve_web_search_description_warns_about_versions()
+ test_agent_prompt_contains_version_awareness_instructions()
+
print("\n✅ All base tool descriptions tests passed!")
diff --git a/uv.lock b/uv.lock
index 152420d8..f10ad033 100644
--- a/uv.lock
+++ b/uv.lock
@@ -5343,6 +5343,7 @@ dependencies = [
{ name = "gitpython" },
{ name = "google-search-results" },
{ name = "json5" },
+ { name = "jsonschema" },
{ name = "litellm" },
{ name = "nbformat" },
{ name = "nemollm" },
@@ -5389,6 +5390,7 @@ requires-dist = [
{ name = "gitpython" },
{ name = "google-search-results", specifier = "==2.4" },
{ name = "json5" },
+ { name = "jsonschema", specifier = ">=4.0.0,<5.0.0" },
{ name = "litellm", specifier = "<=1.75.8" },
{ name = "nbformat" },
{ name = "nemollm" },