Skip to content

Commit c238dba

Browse files
committed
make the way to count abstractions more robust
1 parent c765bff commit c238dba

File tree

2 files changed

+22
-14
lines changed

2 files changed

+22
-14
lines changed

nodes.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -247,45 +247,54 @@ def prep(self, shared):
247247
language = shared.get("language", "english") # Get language
248248
use_cache = shared.get("use_cache", True) # Get use_cache flag, default to True
249249

250+
# Get the actual number of abstractions directly
251+
num_abstractions = len(abstractions)
252+
250253
# Create context with abstraction names, indices, descriptions, and relevant file snippets
251-
context = "Identified Abstractions:\n"
254+
context = "Identified Abstractions:\n"
252255
all_relevant_indices = set()
253256
abstraction_info_for_prompt = []
254257
for i, abstr in enumerate(abstractions):
255258
# Use 'files' which contains indices directly
256259
file_indices_str = ", ".join(map(str, abstr["files"]))
257260
# Abstraction name and description might be translated already
258-
info_line = f"- Index {i}: {abstr['name']} (Relevant file indices: [{file_indices_str}])\n Description: {abstr['description']}"
259-
context += info_line + "\n"
261+
info_line = f"- Index {i}: {abstr['name']} (Relevant file indices: [{file_indices_str}])\n  Description: {abstr['description']}"
262+
context += info_line + "\n"
260263
abstraction_info_for_prompt.append(
261264
f"{i} # {abstr['name']}"
262265
) # Use potentially translated name here too
263266
all_relevant_indices.update(abstr["files"])
264267

265-
context += "\nRelevant File Snippets (Referenced by Index and Path):\n"
268+
context += "\nRelevant File Snippets (Referenced by Index and Path):\n"
266269
# Get content for relevant files using helper
267270
relevant_files_content_map = get_content_for_indices(
268271
files_data, sorted(list(all_relevant_indices))
269272
)
270273
# Format file content for context
271-
file_context_str = "\n\n".join(
272-
f"--- File: {idx_path} ---\n{content}"
274+
file_context_str = "\n\n".join(
275+
f"--- File: {idx_path} ---\n{content}"
273276
for idx_path, content in relevant_files_content_map.items()
274277
)
275278
context += file_context_str
276279

277280
return (
278281
context,
279282
"\n".join(abstraction_info_for_prompt),
283+
num_abstractions, # Pass the actual count
280284
project_name,
281285
language,
282286
use_cache,
283287
) # Return use_cache
284288

285289
def exec(self, prep_res):
286-
context, abstraction_listing, project_name, language, use_cache = (
287-
prep_res # Unpack use_cache
288-
)
290+
(
291+
context,
292+
abstraction_listing,
293+
num_abstractions, # Receive the actual count
294+
project_name,
295+
language,
296+
use_cache,
297+
) = prep_res # Unpack use_cache
289298
print(f"Analyzing relationships using LLM...")
290299

291300
# Add language instruction and hints only if not English
@@ -335,7 +344,7 @@ def exec(self, prep_res):
335344
336345
Now, provide the YAML output:
337346
"""
338-
response = call_llm(prompt)
347+
response = call_llm(prompt, use_cache=use_cache)
339348

340349
# --- Validation ---
341350
yaml_str = response.strip().split("```yaml")[1].split("```")[0].strip()
@@ -354,7 +363,6 @@ def exec(self, prep_res):
354363

355364
# Validate relationships structure
356365
validated_relationships = []
357-
num_abstractions = len(abstraction_listing.split("\n"))
358366
for rel in relationships_data["relationships"]:
359367
# Check for 'label' key
360368
if not isinstance(rel, dict) or not all(

utils/call_llm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def call_llm(prompt: str, use_cache: bool = True) -> str:
4747
logger.info(f"RESPONSE: {cache[prompt]}")
4848
return cache[prompt]
4949

50-
# Call the LLM if not in cache or cache disabled
50+
# # Call the LLM if not in cache or cache disabled
5151
# client = genai.Client(
5252
# vertexai=True,
5353
# # TODO: change to your own project id and location
@@ -59,8 +59,8 @@ def call_llm(prompt: str, use_cache: bool = True) -> str:
5959
client = genai.Client(
6060
api_key=os.getenv("GEMINI_API_KEY", ""),
6161
)
62-
# model = os.getenv("GEMINI_MODEL", "gemini-2.5-pro-exp-03-25")
63-
model = os.getenv("GEMINI_MODEL", "gemini-2.0-flash-exp")
62+
model = os.getenv("GEMINI_MODEL", "gemini-2.5-pro-exp-03-25")
63+
# model = os.getenv("GEMINI_MODEL", "gemini-2.5-flash-preview-04-17")
6464

6565
response = client.models.generate_content(model=model, contents=[prompt])
6666
response_text = response.text

0 commit comments

Comments (0)