More PR-Agent commands \n\n"
output += HelpMessage.get_general_bot_help_text()
@@ -175,7 +175,6 @@ You can ask questions about the entire PR, about specific code lines, or about a
return output
-
@staticmethod
def get_improve_usage_guide():
output = "**Overview:**\n"
diff --git a/apps/utils/pr_agent/servers/utils.py b/apps/utils/pr_agent/servers/utils.py
index 4b1ea80..4b3c788 100644
--- a/apps/utils/pr_agent/servers/utils.py
+++ b/apps/utils/pr_agent/servers/utils.py
@@ -18,8 +18,12 @@ def verify_signature(payload_body, secret_token, signature_header):
signature_header: header received from GitHub (x-hub-signature-256)
"""
if not signature_header:
- raise HTTPException(status_code=403, detail="x-hub-signature-256 header is missing!")
- hash_object = hmac.new(secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256)
+ raise HTTPException(
+ status_code=403, detail="x-hub-signature-256 header is missing!"
+ )
+ hash_object = hmac.new(
+ secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256
+ )
expected_signature = "sha256=" + hash_object.hexdigest()
if not hmac.compare_digest(expected_signature, signature_header):
raise HTTPException(status_code=403, detail="Request signatures didn't match!")
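Note on the hunk above: a minimal, self-contained sketch of how the reformatted `verify_signature` flow is typically exercised. The payload and secret are hypothetical; only the hmac/hashlib calls mirror the code in the diff.

```python
import hashlib
import hmac

# Hypothetical webhook inputs for illustration.
payload_body = b'{"action": "opened"}'
secret_token = "webhook-secret"

hash_object = hmac.new(secret_token.encode("utf-8"), msg=payload_body, digestmod=hashlib.sha256)
expected_signature = "sha256=" + hash_object.hexdigest()

# GitHub sends this value in the x-hub-signature-256 header; a mismatch raises 403 above.
signature_header = expected_signature
assert hmac.compare_digest(expected_signature, signature_header)
```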
@@ -27,6 +31,7 @@ def verify_signature(payload_body, secret_token, signature_header):
class RateLimitExceeded(Exception):
"""Raised when the git provider API rate limit has been exceeded."""
+
pass
@@ -66,7 +71,11 @@ class DefaultDictWithTimeout(defaultdict):
request_time = self.__time()
if request_time - self.__last_refresh > self.__refresh_interval:
return
- to_delete = [key for key, key_time in self.__key_times.items() if request_time - key_time > self.__ttl]
+ to_delete = [
+ key
+ for key, key_time in self.__key_times.items()
+ if request_time - key_time > self.__ttl
+ ]
for key in to_delete:
del self[key]
self.__last_refresh = request_time
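The TTL sweep reformatted above can be illustrated with a simplified stand-alone dictionary. The class name and structure below are hypothetical; they only sketch the idea of tracking per-key write times and deleting expired entries, not the actual `DefaultDictWithTimeout` implementation.

```python
import time
from collections import defaultdict


class TTLDict(defaultdict):
    """Simplified sketch: remember when each key was written, drop entries older than ttl."""

    def __init__(self, default_factory=None, ttl=60.0):
        super().__init__(default_factory)
        self._ttl = ttl
        self._key_times = {}

    def __setitem__(self, key, value):
        self._key_times[key] = time.monotonic()
        super().__setitem__(key, value)

    def sweep(self):
        now = time.monotonic()
        to_delete = [k for k, t in self._key_times.items() if now - t > self._ttl]
        for k in to_delete:
            del self[k]
            del self._key_times[k]


counters = TTLDict(int, ttl=60.0)
counters["requests"] += 1
counters.sweep()  # nothing is older than 60 seconds yet, so "requests" survives
```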
diff --git a/apps/utils/pr_agent/tools/pr_add_docs.py b/apps/utils/pr_agent/tools/pr_add_docs.py
index 362e8b5..be59838 100644
--- a/apps/utils/pr_agent/tools/pr_add_docs.py
+++ b/apps/utils/pr_agent/tools/pr_add_docs.py
@@ -17,9 +17,13 @@ from utils.pr_agent.log import get_logger
class PRAddDocs:
- def __init__(self, pr_url: str, cli_mode=False, args: list = None,
- ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
-
+ def __init__(
+ self,
+ pr_url: str,
+ cli_mode=False,
+ args: list = None,
+ ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
+ ):
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
@@ -39,13 +43,16 @@ class PRAddDocs:
"diff": "", # empty diff for initial calculation
"extra_instructions": get_settings().pr_add_docs.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
- 'docs_for_language': get_docs_for_language(self.main_language,
- get_settings().pr_add_docs.docs_style),
+ 'docs_for_language': get_docs_for_language(
+ self.main_language, get_settings().pr_add_docs.docs_style
+ ),
}
- self.token_handler = TokenHandler(self.git_provider.pr,
- self.vars,
- get_settings().pr_add_docs_prompt.system,
- get_settings().pr_add_docs_prompt.user)
+ self.token_handler = TokenHandler(
+ self.git_provider.pr,
+ self.vars,
+ get_settings().pr_add_docs_prompt.system,
+ get_settings().pr_add_docs_prompt.user,
+ )
async def run(self):
try:
@@ -66,16 +73,20 @@ class PRAddDocs:
get_logger().info('Pushing inline code documentation...')
self.push_inline_docs(data)
except Exception as e:
- get_logger().error(f"Failed to generate code documentation for PR, error: {e}")
+ get_logger().error(
+ f"Failed to generate code documentation for PR, error: {e}"
+ )
async def _prepare_prediction(self, model: str):
get_logger().info('Getting PR diff...')
- self.patches_diff = get_pr_diff(self.git_provider,
- self.token_handler,
- model,
- add_line_numbers_to_hunks=True,
- disable_extra_lines=False)
+ self.patches_diff = get_pr_diff(
+ self.git_provider,
+ self.token_handler,
+ model,
+ add_line_numbers_to_hunks=True,
+ disable_extra_lines=False,
+ )
get_logger().info('Getting AI prediction...')
self.prediction = await self._get_prediction(model)
@@ -84,13 +95,21 @@ class PRAddDocs:
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
- system_prompt = environment.from_string(get_settings().pr_add_docs_prompt.system).render(variables)
- user_prompt = environment.from_string(get_settings().pr_add_docs_prompt.user).render(variables)
+ system_prompt = environment.from_string(
+ get_settings().pr_add_docs_prompt.system
+ ).render(variables)
+ user_prompt = environment.from_string(
+ get_settings().pr_add_docs_prompt.user
+ ).render(variables)
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"\nSystem prompt:\n{system_prompt}")
get_logger().info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(
- model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
+ model=model,
+ temperature=get_settings().config.temperature,
+ system=system_prompt,
+ user=user_prompt,
+ )
return response
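The prompt-rendering pattern wrapped above (an `Environment` with `StrictUndefined`, then `from_string(...).render(variables)`) behaves as in this small sketch; the template text and variables are hypothetical stand-ins for the configured prompts.

```python
from jinja2 import Environment, StrictUndefined

variables = {"diff": "@@ -1 +1 @@\n-old\n+new", "extra_instructions": ""}
environment = Environment(undefined=StrictUndefined)

system_prompt = environment.from_string(
    "Add docstrings to this diff:\n{{ diff }}\n{{ extra_instructions }}"
).render(variables)

# StrictUndefined turns a missing variable into an immediate error instead of
# silently rendering an empty string:
# environment.from_string("{{ missing }}").render(variables)  # raises UndefinedError
```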
@@ -105,7 +124,9 @@ class PRAddDocs:
docs = []
if not data['Code Documentation']:
- return self.git_provider.publish_comment('No code documentation found to improve this PR.')
+ return self.git_provider.publish_comment(
+ 'No code documentation found to improve this PR.'
+ )
for d in data['Code Documentation']:
try:
@@ -116,32 +137,59 @@ class PRAddDocs:
documentation = d['documentation']
doc_placement = d['doc placement'].strip()
if documentation:
- new_code_snippet = self.dedent_code(relevant_file, relevant_line, documentation, doc_placement,
- add_original_line=True)
+ new_code_snippet = self.dedent_code(
+ relevant_file,
+ relevant_line,
+ documentation,
+ doc_placement,
+ add_original_line=True,
+ )
- body = f"**Suggestion:** Proposed documentation\n```suggestion\n" + new_code_snippet + "\n```"
- docs.append({'body': body, 'relevant_file': relevant_file,
- 'relevant_lines_start': relevant_line,
- 'relevant_lines_end': relevant_line})
+ body = (
+ f"**Suggestion:** Proposed documentation\n```suggestion\n"
+ + new_code_snippet
+ + "\n```"
+ )
+ docs.append(
+ {
+ 'body': body,
+ 'relevant_file': relevant_file,
+ 'relevant_lines_start': relevant_line,
+ 'relevant_lines_end': relevant_line,
+ }
+ )
except Exception:
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Could not parse code docs: {d}")
is_successful = self.git_provider.publish_code_suggestions(docs)
if not is_successful:
- get_logger().info("Failed to publish code docs, trying to publish each docs separately")
+ get_logger().info(
+ "Failed to publish code docs, trying to publish each docs separately"
+ )
for doc_suggestion in docs:
self.git_provider.publish_code_suggestions([doc_suggestion])
- def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet, doc_placement='after',
- add_original_line=False):
+ def dedent_code(
+ self,
+ relevant_file,
+ relevant_lines_start,
+ new_code_snippet,
+ doc_placement='after',
+ add_original_line=False,
+ ):
try: # dedent code snippet
- self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files \
+ self.diff_files = (
+ self.git_provider.diff_files
+ if self.git_provider.diff_files
else self.git_provider.get_diff_files()
+ )
original_initial_line = None
for file in self.diff_files:
if file.filename.strip() == relevant_file:
- original_initial_line = file.head_file.splitlines()[relevant_lines_start - 1]
+ original_initial_line = file.head_file.splitlines()[
+ relevant_lines_start - 1
+ ]
break
if original_initial_line:
if doc_placement == 'after':
@@ -150,18 +198,28 @@ class PRAddDocs:
line = original_initial_line
suggested_initial_line = new_code_snippet.splitlines()[0]
original_initial_spaces = len(line) - len(line.lstrip())
- suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip())
+ suggested_initial_spaces = len(suggested_initial_line) - len(
+ suggested_initial_line.lstrip()
+ )
delta_spaces = original_initial_spaces - suggested_initial_spaces
if delta_spaces > 0:
- new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
+ new_code_snippet = textwrap.indent(
+ new_code_snippet, delta_spaces * " "
+ ).rstrip('\n')
if add_original_line:
if doc_placement == 'after':
- new_code_snippet = original_initial_line + "\n" + new_code_snippet
+ new_code_snippet = (
+ original_initial_line + "\n" + new_code_snippet
+ )
else:
- new_code_snippet = new_code_snippet.rstrip() + "\n" + original_initial_line
+ new_code_snippet = (
+ new_code_snippet.rstrip() + "\n" + original_initial_line
+ )
except Exception as e:
if get_settings().config.verbosity_level >= 2:
- get_logger().info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")
+ get_logger().info(
+ f"Could not dedent code snippet for file {relevant_file}, error: {e}"
+ )
return new_code_snippet
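A rough sketch of the indentation-matching step that `dedent_code` performs in the hunks above, with assumed inputs (this is not the method itself, just the core arithmetic and `textwrap.indent` call):

```python
import textwrap

# Hypothetical original line from the PR and a suggested snippet with less indentation.
original_initial_line = "        return value"
new_code_snippet = "def helper():\n    return value"

original_initial_spaces = len(original_initial_line) - len(original_initial_line.lstrip())
suggested_initial_line = new_code_snippet.splitlines()[0]
suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip())

delta_spaces = original_initial_spaces - suggested_initial_spaces
if delta_spaces > 0:
    # Shift every line of the snippet right so it lines up with the original code.
    new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip("\n")

print(new_code_snippet)  # both lines are now shifted right by the 8-space delta
```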
diff --git a/apps/utils/pr_agent/tools/pr_code_suggestions.py b/apps/utils/pr_agent/tools/pr_code_suggestions.py
index b0dd5c7..9fce273 100644
--- a/apps/utils/pr_agent/tools/pr_code_suggestions.py
+++ b/apps/utils/pr_agent/tools/pr_code_suggestions.py
@@ -12,15 +12,25 @@ from jinja2 import Environment, StrictUndefined
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
-from utils.pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files,
- get_pr_diff, get_pr_multi_diffs,
- retry_with_fallback_models)
+from utils.pr_agent.algo.pr_processing import (
+ add_ai_metadata_to_diff_files,
+ get_pr_diff,
+ get_pr_multi_diffs,
+ retry_with_fallback_models,
+)
from utils.pr_agent.algo.token_handler import TokenHandler
-from utils.pr_agent.algo.utils import (ModelType, load_yaml, replace_code_tags,
- show_relevant_configurations)
+from utils.pr_agent.algo.utils import (
+ ModelType,
+ load_yaml,
+ replace_code_tags,
+ show_relevant_configurations,
+)
from utils.pr_agent.config_loader import get_settings
-from utils.pr_agent.git_providers import (AzureDevopsProvider, GithubProvider,
- get_git_provider_with_context)
+from utils.pr_agent.git_providers import (
+ AzureDevopsProvider,
+ GithubProvider,
+ get_git_provider_with_context,
+)
from utils.pr_agent.git_providers.git_provider import get_main_pr_language, GitProvider
from utils.pr_agent.log import get_logger
from utils.pr_agent.servers.help import HelpMessage
@@ -28,9 +38,13 @@ from utils.pr_agent.tools.pr_description import insert_br_after_x_chars
class PRCodeSuggestions:
- def __init__(self, pr_url: str, cli_mode=False, args: list = None,
- ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
-
+ def __init__(
+ self,
+ pr_url: str,
+ cli_mode=False,
+ args: list = None,
+ ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
+ ):
self.git_provider = get_git_provider_with_context(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
@@ -38,10 +52,16 @@ class PRCodeSuggestions:
# limit context specifically for the improve command, which has hard input to parse:
if get_settings().pr_code_suggestions.max_context_tokens:
- MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
+ MAX_CONTEXT_TOKENS_IMPROVE = (
+ get_settings().pr_code_suggestions.max_context_tokens
+ )
if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
- get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
- get_settings().config.max_model_tokens_original = get_settings().config.max_model_tokens
+ get_logger().info(
+ f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve"
+ )
+ get_settings().config.max_model_tokens_original = (
+ get_settings().config.max_model_tokens
+ )
get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
# extended mode
@@ -49,8 +69,9 @@ class PRCodeSuggestions:
self.is_extended = self._get_is_extended(args or [])
except:
self.is_extended = False
- num_code_suggestions = int(get_settings().pr_code_suggestions.num_code_suggestions_per_chunk)
-
+ num_code_suggestions = int(
+ get_settings().pr_code_suggestions.num_code_suggestions_per_chunk
+ )
self.ai_handler = ai_handler()
self.ai_handler.main_pr_language = self.main_language
@@ -58,10 +79,15 @@ class PRCodeSuggestions:
self.prediction = None
self.pr_url = pr_url
self.cli_mode = cli_mode
- self.pr_description, self.pr_description_files = (
- self.git_provider.get_pr_description(split_changes_walkthrough=True))
- if (self.pr_description_files and get_settings().get("config.is_auto_command", False) and
- get_settings().get("config.enable_ai_metadata", False)):
+ (
+ self.pr_description,
+ self.pr_description_files,
+ ) = self.git_provider.get_pr_description(split_changes_walkthrough=True)
+ if (
+ self.pr_description_files
+ and get_settings().get("config.is_auto_command", False)
+ and get_settings().get("config.enable_ai_metadata", False)
+ ):
add_ai_metadata_to_diff_files(self.git_provider, self.pr_description_files)
get_logger().debug(f"AI metadata added to the this command")
else:
@@ -80,16 +106,24 @@ class PRCodeSuggestions:
"commit_messages_str": self.git_provider.get_commit_messages(),
"relevant_best_practices": "",
"is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
- "focus_only_on_problems": get_settings().get("pr_code_suggestions.focus_only_on_problems", False),
+ "focus_only_on_problems": get_settings().get(
+ "pr_code_suggestions.focus_only_on_problems", False
+ ),
"date": datetime.now().strftime('%Y-%m-%d'),
- 'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False),
+ 'duplicate_prompt_examples': get_settings().config.get(
+ 'duplicate_prompt_examples', False
+ ),
}
- self.pr_code_suggestions_prompt_system = get_settings().pr_code_suggestions_prompt.system
+ self.pr_code_suggestions_prompt_system = (
+ get_settings().pr_code_suggestions_prompt.system
+ )
- self.token_handler = TokenHandler(self.git_provider.pr,
- self.vars,
- self.pr_code_suggestions_prompt_system,
- get_settings().pr_code_suggestions_prompt.user)
+ self.token_handler = TokenHandler(
+ self.git_provider.pr,
+ self.vars,
+ self.pr_code_suggestions_prompt_system,
+ get_settings().pr_code_suggestions_prompt.user,
+ )
self.progress = f"## 生成 PR 代码建议\n\n"
self.progress += f"""\n思考中 ... \n """
@@ -98,33 +132,50 @@ class PRCodeSuggestions:
async def run(self):
try:
if not self.git_provider.get_files():
- get_logger().info(f"PR has no files: {self.pr_url}, skipping code suggestions")
+ get_logger().info(
+ f"PR has no files: {self.pr_url}, skipping code suggestions"
+ )
return None
get_logger().info('Generating code suggestions for PR...')
- relevant_configs = {'pr_code_suggestions': dict(get_settings().pr_code_suggestions),
- 'config': dict(get_settings().config)}
+ relevant_configs = {
+ 'pr_code_suggestions': dict(get_settings().pr_code_suggestions),
+ 'config': dict(get_settings().config),
+ }
get_logger().debug("Relevant configs", artifacts=relevant_configs)
# publish "Preparing suggestions..." comments
- if (get_settings().config.publish_output and get_settings().config.publish_output_progress and
- not get_settings().config.get('is_auto_command', False)):
+ if (
+ get_settings().config.publish_output
+ and get_settings().config.publish_output_progress
+ and not get_settings().config.get('is_auto_command', False)
+ ):
if self.git_provider.is_supported("gfm_markdown"):
- self.progress_response = self.git_provider.publish_comment(self.progress)
+ self.progress_response = self.git_provider.publish_comment(
+ self.progress
+ )
else:
self.git_provider.publish_comment("准备建议中...", is_temporary=True)
# call the model to get the suggestions, and self-reflect on them
if not self.is_extended:
- data = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
+ data = await retry_with_fallback_models(
+ self._prepare_prediction, model_type=ModelType.REGULAR
+ )
else:
- data = await retry_with_fallback_models(self._prepare_prediction_extended, model_type=ModelType.REGULAR)
+ data = await retry_with_fallback_models(
+ self._prepare_prediction_extended, model_type=ModelType.REGULAR
+ )
if not data:
data = {"code_suggestions": []}
self.data = data
# Handle the case where the PR has no suggestions
- if (data is None or 'code_suggestions' not in data or not data['code_suggestions']):
+ if (
+ data is None
+ or 'code_suggestions' not in data
+ or not data['code_suggestions']
+ ):
await self.publish_no_suggestions()
return
@@ -134,20 +185,25 @@ class PRCodeSuggestions:
self.git_provider.remove_initial_comment()
# Publish table summarized suggestions
- if ((not get_settings().pr_code_suggestions.commitable_code_suggestions) and
- self.git_provider.is_supported("gfm_markdown")):
-
+ if (
+ not get_settings().pr_code_suggestions.commitable_code_suggestions
+ ) and self.git_provider.is_supported("gfm_markdown"):
# generate summarized suggestions
pr_body = self.generate_summarized_suggestions(data)
get_logger().debug(f"PR output", artifact=pr_body)
# require self-review
- if get_settings().pr_code_suggestions.demand_code_suggestions_self_review:
+ if (
+ get_settings().pr_code_suggestions.demand_code_suggestions_self_review
+ ):
pr_body = await self.add_self_review_text(pr_body)
# add usage guide
- if (get_settings().pr_code_suggestions.enable_chat_text and get_settings().config.is_auto_command
- and isinstance(self.git_provider, GithubProvider)):
+ if (
+ get_settings().pr_code_suggestions.enable_chat_text
+ and get_settings().config.is_auto_command
+ and isinstance(self.git_provider, GithubProvider)
+ ):
pr_body += "\n\n>💡 Need additional feedback ? start a [PR chat](https://chromewebstore.google.com/detail/ephlnjeghhogofkifjloamocljapahnl) \n\n"
if get_settings().pr_code_suggestions.enable_help_text:
pr_body += " \n\n 💡 Tool usage guide: \n\n"
@@ -155,55 +211,84 @@ class PRCodeSuggestions:
pr_body += "\n \n"
# Output the relevant configurations if enabled
- if get_settings().get('config', {}).get('output_relevant_configurations', False):
- pr_body += show_relevant_configurations(relevant_section='pr_code_suggestions')
+ if (
+ get_settings()
+ .get('config', {})
+ .get('output_relevant_configurations', False)
+ ):
+ pr_body += show_relevant_configurations(
+ relevant_section='pr_code_suggestions'
+ )
# publish the PR comment
- if get_settings().pr_code_suggestions.persistent_comment: # true by default
- self.publish_persistent_comment_with_history(self.git_provider,
- pr_body,
- initial_header="## PR 代码建议 ✨",
- update_header=True,
- name="suggestions",
- final_update_message=False,
- max_previous_comments=get_settings().pr_code_suggestions.max_history_len,
- progress_response=self.progress_response)
+ if (
+ get_settings().pr_code_suggestions.persistent_comment
+ ): # true by default
+ self.publish_persistent_comment_with_history(
+ self.git_provider,
+ pr_body,
+ initial_header="## PR 代码建议 ✨",
+ update_header=True,
+ name="suggestions",
+ final_update_message=False,
+ max_previous_comments=get_settings().pr_code_suggestions.max_history_len,
+ progress_response=self.progress_response,
+ )
else:
if self.progress_response:
- self.git_provider.edit_comment(self.progress_response, body=pr_body)
+ self.git_provider.edit_comment(
+ self.progress_response, body=pr_body
+ )
else:
self.git_provider.publish_comment(pr_body)
# dual publishing mode
- if int(get_settings().pr_code_suggestions.dual_publishing_score_threshold) > 0:
+ if (
+ int(
+ get_settings().pr_code_suggestions.dual_publishing_score_threshold
+ )
+ > 0
+ ):
await self.dual_publishing(data)
else:
await self.push_inline_code_suggestions(data)
if self.progress_response:
self.git_provider.remove_comment(self.progress_response)
else:
- get_logger().info('Code suggestions generated for PR, but not published since publish_output is False.')
+ get_logger().info(
+ 'Code suggestions generated for PR, but not published since publish_output is False.'
+ )
pr_body = self.generate_summarized_suggestions(data)
get_settings().data = {"artifact": pr_body}
return
except Exception as e:
- get_logger().error(f"Failed to generate code suggestions for PR, error: {e}",
- artifact={"traceback": traceback.format_exc()})
+ get_logger().error(
+ f"Failed to generate code suggestions for PR, error: {e}",
+ artifact={"traceback": traceback.format_exc()},
+ )
if get_settings().config.publish_output:
if self.progress_response:
self.progress_response.delete()
else:
try:
self.git_provider.remove_initial_comment()
- self.git_provider.publish_comment(f"Failed to generate code suggestions for PR")
+ self.git_provider.publish_comment(
+ f"Failed to generate code suggestions for PR"
+ )
except Exception as e:
- get_logger().exception(f"Failed to update persistent review, error: {e}")
+ get_logger().exception(
+ f"Failed to update persistent review, error: {e}"
+ )
async def add_self_review_text(self, pr_body):
text = get_settings().pr_code_suggestions.code_suggestions_self_review_text
pr_body += f"\n\n- [ ] {text}"
- approve_pr_on_self_review = get_settings().pr_code_suggestions.approve_pr_on_self_review
- fold_suggestions_on_self_review = get_settings().pr_code_suggestions.fold_suggestions_on_self_review
+ approve_pr_on_self_review = (
+ get_settings().pr_code_suggestions.approve_pr_on_self_review
+ )
+ fold_suggestions_on_self_review = (
+ get_settings().pr_code_suggestions.fold_suggestions_on_self_review
+ )
if approve_pr_on_self_review and not fold_suggestions_on_self_review:
pr_body += ' '
elif fold_suggestions_on_self_review and not approve_pr_on_self_review:
@@ -214,7 +299,10 @@ class PRCodeSuggestions:
async def publish_no_suggestions(self):
pr_body = "## PR 代码建议 ✨\n\n未找到该PR的代码建议."
- if get_settings().config.publish_output and get_settings().config.publish_output_no_suggestions:
+ if (
+ get_settings().config.publish_output
+ and get_settings().config.publish_output_no_suggestions
+ ):
get_logger().warning('No code suggestions found for the PR.')
get_logger().debug(f"PR output", artifact=pr_body)
if self.progress_response:
@@ -229,31 +317,40 @@ class PRCodeSuggestions:
try:
for suggestion in data['code_suggestions']:
if int(suggestion.get('score', 0)) >= int(
- get_settings().pr_code_suggestions.dual_publishing_score_threshold) \
- and suggestion.get('improved_code'):
+ get_settings().pr_code_suggestions.dual_publishing_score_threshold
+ ) and suggestion.get('improved_code'):
data_above_threshold['code_suggestions'].append(suggestion)
- if not data_above_threshold['code_suggestions'][-1]['existing_code']:
- get_logger().info(f'Identical existing and improved code for dual publishing found')
- data_above_threshold['code_suggestions'][-1]['existing_code'] = suggestion[
- 'improved_code']
+ if not data_above_threshold['code_suggestions'][-1][
+ 'existing_code'
+ ]:
+ get_logger().info(
+ f'Identical existing and improved code for dual publishing found'
+ )
+ data_above_threshold['code_suggestions'][-1][
+ 'existing_code'
+ ] = suggestion['improved_code']
if data_above_threshold['code_suggestions']:
get_logger().info(
- f"Publishing {len(data_above_threshold['code_suggestions'])} suggestions in dual publishing mode")
+ f"Publishing {len(data_above_threshold['code_suggestions'])} suggestions in dual publishing mode"
+ )
await self.push_inline_code_suggestions(data_above_threshold)
except Exception as e:
- get_logger().error(f"Failed to publish dual publishing suggestions, error: {e}")
+ get_logger().error(
+ f"Failed to publish dual publishing suggestions, error: {e}"
+ )
@staticmethod
- def publish_persistent_comment_with_history(git_provider: GitProvider,
- pr_comment: str,
- initial_header: str,
- update_header: bool = True,
- name='review',
- final_update_message=True,
- max_previous_comments=4,
- progress_response=None,
- only_fold=False):
-
+ def publish_persistent_comment_with_history(
+ git_provider: GitProvider,
+ pr_comment: str,
+ initial_header: str,
+ update_header: bool = True,
+ name='review',
+ final_update_message=True,
+ max_previous_comments=4,
+ progress_response=None,
+ only_fold=False,
+ ):
def _extract_link(comment_text: str):
r = re.compile(r"<!--.*?-->")
match = r.search(comment_text)
@@ -263,7 +360,9 @@ class PRCodeSuggestions:
up_to_commit_txt = f" up to commit {match.group(0)[4:-3].strip()}"
return up_to_commit_txt
- if isinstance(git_provider, AzureDevopsProvider): # get_latest_commit_url is not supported yet
+ if isinstance(
+ git_provider, AzureDevopsProvider
+ ): # get_latest_commit_url is not supported yet
if progress_response:
git_provider.edit_comment(progress_response, pr_comment)
new_comment = progress_response
@@ -273,7 +372,7 @@ class PRCodeSuggestions:
history_header = f"#### Previous suggestions\n"
last_commit_num = git_provider.get_latest_commit_url().split('/')[-1][:7]
- if only_fold: # A user clicked on the 'self-review' checkbox
+ if only_fold: # A user clicked on the 'self-review' checkbox
text = get_settings().pr_code_suggestions.code_suggestions_self_review_text
latest_suggestion_header = f"\n\n- [x] {text}"
else:
@@ -300,42 +399,66 @@ class PRCodeSuggestions:
# find http link from comment.body[:table_index]
up_to_commit_txt = _extract_link(comment.body[:table_index])
prev_suggestion_table = comment.body[
- table_index:comment.body.rfind("</table>") + len("</table>")]
+ table_index : comment.body.rfind("</table>")
+ + len("</table>")
+ ]
tick = "✅ " if "✅" in prev_suggestion_table else ""
# surround with details tag
prev_suggestion_table = f"{tick}{name.capitalize()}{up_to_commit_txt}\n {prev_suggestion_table}\n\n "
- new_suggestion_table = pr_comment.replace(initial_header, "").strip()
+ new_suggestion_table = pr_comment.replace(
+ initial_header, ""
+ ).strip()
- pr_comment_updated = f"{initial_header}\n{latest_commit_html_comment}\n\n"
+ pr_comment_updated = (
+ f"{initial_header}\n{latest_commit_html_comment}\n\n"
+ )
pr_comment_updated += f"{latest_suggestion_header}\n{new_suggestion_table}\n\n___\n\n"
- pr_comment_updated += f"{history_header}{prev_suggestion_table}\n"
+ pr_comment_updated += (
+ f"{history_header}{prev_suggestion_table}\n"
+ )
else:
# get the text of the previous suggestions until the latest commit
sections = prev_suggestions.split(history_header.strip())
latest_table = sections[0].strip()
- prev_suggestion_table = sections[1].replace(history_header, "").strip()
+ prev_suggestion_table = (
+ sections[1].replace(history_header, "").strip()
+ )
# get text after the latest_suggestion_header in comment.body
table_ind = latest_table.find("<table>")
up_to_commit_txt = _extract_link(latest_table[:table_ind])
- latest_table = latest_table[table_ind:latest_table.rfind("</table>") + len("</table>")]
+ latest_table = latest_table[
+ table_ind : latest_table.rfind("</table>")
+ + len("</table>")
+ ]
# enforce max_previous_comments
- count = prev_suggestions.count(f"\n<details><summary>{name.capitalize()}")
- count += prev_suggestions.count(f"\n<details><summary>✅ {name.capitalize()}")
+ count = prev_suggestions.count(
+ f"\n<details><summary>{name.capitalize()}"
+ )
+ count += prev_suggestions.count(
+ f"\n<details><summary>✅ {name.capitalize()}"
+ )
if count >= max_previous_comments:
# remove the oldest suggestion
- prev_suggestion_table = prev_suggestion_table[:prev_suggestion_table.rfind(
- f"<details><summary>{name.capitalize()} up to commit")]
+ prev_suggestion_table = prev_suggestion_table[
+ : prev_suggestion_table.rfind(
+ f"<details><summary>{name.capitalize()} up to commit"
+ )
+ ]
tick = "✅ " if "✅" in latest_table else ""
# Add to the prev_suggestions section
last_prev_table = f"\n{tick}{name.capitalize()}{up_to_commit_txt}\n {latest_table}\n\n "
- prev_suggestion_table = last_prev_table + "\n" + prev_suggestion_table
+ prev_suggestion_table = (
+ last_prev_table + "\n" + prev_suggestion_table
+ )
- new_suggestion_table = pr_comment.replace(initial_header, "").strip()
+ new_suggestion_table = pr_comment.replace(
+ initial_header, ""
+ ).strip()
pr_comment_updated = f"{initial_header}\n"
pr_comment_updated += f"{latest_commit_html_comment}\n\n"
@@ -344,16 +467,24 @@ class PRCodeSuggestions:
pr_comment_updated += f"{history_header}\n"
pr_comment_updated += f"{prev_suggestion_table}\n"
- get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message")
- if progress_response: # publish to 'progress_response' comment, because it refreshes immediately
- git_provider.edit_comment(progress_response, pr_comment_updated)
+ get_logger().info(
+ f"Persistent mode - updating comment {comment_url} to latest {name} message"
+ )
+ if (
+ progress_response
+ ): # publish to 'progress_response' comment, because it refreshes immediately
+ git_provider.edit_comment(
+ progress_response, pr_comment_updated
+ )
git_provider.remove_comment(comment)
comment = progress_response
else:
git_provider.edit_comment(comment, pr_comment_updated)
return comment
except Exception as e:
- get_logger().exception(f"Failed to update persistent review, error: {e}")
+ get_logger().exception(
+ f"Failed to update persistent review, error: {e}"
+ )
pass
# if we are here, we did not find a previous comment to update
@@ -366,7 +497,6 @@ class PRCodeSuggestions:
new_comment = git_provider.publish_comment(pr_comment)
return new_comment
-
def extract_link(self, s):
r = re.compile(r"<!--.*?-->")
match = r.search(s)
@@ -377,17 +507,23 @@ class PRCodeSuggestions:
return up_to_commit_txt
async def _prepare_prediction(self, model: str) -> dict:
- self.patches_diff = get_pr_diff(self.git_provider,
- self.token_handler,
- model,
- add_line_numbers_to_hunks=True,
- disable_extra_lines=False)
+ self.patches_diff = get_pr_diff(
+ self.git_provider,
+ self.token_handler,
+ model,
+ add_line_numbers_to_hunks=True,
+ disable_extra_lines=False,
+ )
self.patches_diff_list = [self.patches_diff]
- self.patches_diff_no_line_number = self.remove_line_numbers([self.patches_diff])[0]
+ self.patches_diff_no_line_number = self.remove_line_numbers(
+ [self.patches_diff]
+ )[0]
if self.patches_diff:
get_logger().debug(f"PR diff", artifact=self.patches_diff)
- self.prediction = await self._get_prediction(model, self.patches_diff, self.patches_diff_no_line_number)
+ self.prediction = await self._get_prediction(
+ model, self.patches_diff, self.patches_diff_no_line_number
+ )
else:
get_logger().warning(f"Empty PR diff")
self.prediction = None
@@ -395,15 +531,25 @@ class PRCodeSuggestions:
data = self.prediction
return data
- async def _get_prediction(self, model: str, patches_diff: str, patches_diff_no_line_number: str) -> dict:
+ async def _get_prediction(
+ self, model: str, patches_diff: str, patches_diff_no_line_number: str
+ ) -> dict:
variables = copy.deepcopy(self.vars)
variables["diff"] = patches_diff # update diff
variables["diff_no_line_numbers"] = patches_diff_no_line_number # update diff
environment = Environment(undefined=StrictUndefined)
- system_prompt = environment.from_string(self.pr_code_suggestions_prompt_system).render(variables)
- user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
+ system_prompt = environment.from_string(
+ self.pr_code_suggestions_prompt_system
+ ).render(variables)
+ user_prompt = environment.from_string(
+ get_settings().pr_code_suggestions_prompt.user
+ ).render(variables)
response, finish_reason = await self.ai_handler.chat_completion(
- model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
+ model=model,
+ temperature=get_settings().config.temperature,
+ system=system_prompt,
+ user=user_prompt,
+ )
if not get_settings().config.publish_output:
get_settings().system_prompt = system_prompt
get_settings().user_prompt = user_prompt
@@ -413,8 +559,9 @@ class PRCodeSuggestions:
# self-reflect on suggestions (mandatory, since line numbers are generated now here)
model_reflection = get_settings().config.model
- response_reflect = await self.self_reflect_on_suggestions(data["code_suggestions"],
- patches_diff, model=model_reflection)
+ response_reflect = await self.self_reflect_on_suggestions(
+ data["code_suggestions"], patches_diff, model=model_reflection
+ )
if response_reflect:
await self.analyze_self_reflection_response(data, response_reflect)
else:
@@ -428,15 +575,23 @@ class PRCodeSuggestions:
async def analyze_self_reflection_response(self, data, response_reflect):
response_reflect_yaml = load_yaml(response_reflect)
code_suggestions_feedback = response_reflect_yaml.get("code_suggestions", [])
- if code_suggestions_feedback and len(code_suggestions_feedback) == len(data["code_suggestions"]):
+ if code_suggestions_feedback and len(code_suggestions_feedback) == len(
+ data["code_suggestions"]
+ ):
for i, suggestion in enumerate(data["code_suggestions"]):
try:
- suggestion["score"] = code_suggestions_feedback[i]["suggestion_score"]
+ suggestion["score"] = code_suggestions_feedback[i][
+ "suggestion_score"
+ ]
suggestion["score_why"] = code_suggestions_feedback[i]["why"]
if 'relevant_lines_start' not in suggestion:
- relevant_lines_start = code_suggestions_feedback[i].get('relevant_lines_start', -1)
- relevant_lines_end = code_suggestions_feedback[i].get('relevant_lines_end', -1)
+ relevant_lines_start = code_suggestions_feedback[i].get(
+ 'relevant_lines_start', -1
+ )
+ relevant_lines_end = code_suggestions_feedback[i].get(
+ 'relevant_lines_end', -1
+ )
suggestion['relevant_lines_start'] = relevant_lines_start
suggestion['relevant_lines_end'] = relevant_lines_end
if relevant_lines_start < 0 or relevant_lines_end < 0:
@@ -450,18 +605,29 @@ class PRCodeSuggestions:
score = int(suggestion["score"])
label = suggestion["label"].lower().strip()
label = label.replace(' ', ' ')
- suggestion_statistics_dict = {'score': score,
- 'label': label}
- get_logger().info(f"PR-Agent suggestions statistics",
- statistics=suggestion_statistics_dict, analytics=True)
+ suggestion_statistics_dict = {
+ 'score': score,
+ 'label': label,
+ }
+ get_logger().info(
+ f"PR-Agent suggestions statistics",
+ statistics=suggestion_statistics_dict,
+ analytics=True,
+ )
except Exception as e:
- get_logger().error(f"Failed to log suggestion statistics, error: {e}")
+ get_logger().error(
+ f"Failed to log suggestion statistics, error: {e}"
+ )
pass
except Exception as e: #
- get_logger().error(f"Error processing suggestion score {i}",
- artifact={"suggestion": suggestion,
- "code_suggestions_feedback": code_suggestions_feedback[i]})
+ get_logger().error(
+ f"Error processing suggestion score {i}",
+ artifact={
+ "suggestion": suggestion,
+ "code_suggestions_feedback": code_suggestions_feedback[i],
+ },
+ )
suggestion["score"] = 7
suggestion["score_why"] = ""
@@ -469,30 +635,53 @@ class PRCodeSuggestions:
try:
if suggestion['existing_code'] == suggestion['improved_code']:
get_logger().debug(
- f"edited improved suggestion {i + 1}, because equal to existing code: {suggestion['existing_code']}")
- if get_settings().pr_code_suggestions.commitable_code_suggestions:
- suggestion['improved_code'] = "" # we need 'existing_code' to locate the code in the PR
+ f"edited improved suggestion {i + 1}, because equal to existing code: {suggestion['existing_code']}"
+ )
+ if (
+ get_settings().pr_code_suggestions.commitable_code_suggestions
+ ):
+ suggestion[
+ 'improved_code'
+ ] = "" # we need 'existing_code' to locate the code in the PR
else:
suggestion['existing_code'] = ""
except Exception as e:
- get_logger().error(f"Error processing suggestion {i + 1}, error: {e}")
+ get_logger().error(
+ f"Error processing suggestion {i + 1}, error: {e}"
+ )
@staticmethod
def _truncate_if_needed(suggestion):
- max_code_suggestion_length = get_settings().get("PR_CODE_SUGGESTIONS.MAX_CODE_SUGGESTION_LENGTH", 0)
- suggestion_truncation_message = get_settings().get("PR_CODE_SUGGESTIONS.SUGGESTION_TRUNCATION_MESSAGE", "")
+ max_code_suggestion_length = get_settings().get(
+ "PR_CODE_SUGGESTIONS.MAX_CODE_SUGGESTION_LENGTH", 0
+ )
+ suggestion_truncation_message = get_settings().get(
+ "PR_CODE_SUGGESTIONS.SUGGESTION_TRUNCATION_MESSAGE", ""
+ )
if max_code_suggestion_length > 0:
if len(suggestion['improved_code']) > max_code_suggestion_length:
- get_logger().info(f"Truncated suggestion from {len(suggestion['improved_code'])} "
- f"characters to {max_code_suggestion_length} characters")
- suggestion['improved_code'] = suggestion['improved_code'][:max_code_suggestion_length]
+ get_logger().info(
+ f"Truncated suggestion from {len(suggestion['improved_code'])} "
+ f"characters to {max_code_suggestion_length} characters"
+ )
+ suggestion['improved_code'] = suggestion['improved_code'][
+ :max_code_suggestion_length
+ ]
suggestion['improved_code'] += f"\n{suggestion_truncation_message}"
return suggestion
def _prepare_pr_code_suggestions(self, predictions: str) -> Dict:
- data = load_yaml(predictions.strip(),
- keys_fix_yaml=["relevant_file", "suggestion_content", "existing_code", "improved_code"],
- first_key="code_suggestions", last_key="label")
+ data = load_yaml(
+ predictions.strip(),
+ keys_fix_yaml=[
+ "relevant_file",
+ "suggestion_content",
+ "existing_code",
+ "improved_code",
+ ],
+ first_key="code_suggestions",
+ last_key="label",
+ )
if isinstance(data, list):
data = {'code_suggestions': data}
@@ -507,24 +696,35 @@ class PRCodeSuggestions:
if key not in suggestion:
is_valid_keys = False
get_logger().debug(
- f"Skipping suggestion {i + 1}, because it does not contain '{key}':\n'{suggestion}")
+ f"Skipping suggestion {i + 1}, because it does not contain '{key}':\n'{suggestion}"
+ )
break
if not is_valid_keys:
continue
- if get_settings().get("pr_code_suggestions.focus_only_on_problems", False):
+ if get_settings().get(
+ "pr_code_suggestions.focus_only_on_problems", False
+ ):
CRITICAL_LABEL = 'critical'
- if CRITICAL_LABEL in suggestion['label'].lower(): # we want the published labels to be less declarative
+ if (
+ CRITICAL_LABEL in suggestion['label'].lower()
+ ): # we want the published labels to be less declarative
suggestion['label'] = 'possible issue'
if suggestion['one_sentence_summary'] in one_sentence_summary_list:
- get_logger().debug(f"Skipping suggestion {i + 1}, because it is a duplicate: {suggestion}")
+ get_logger().debug(
+ f"Skipping suggestion {i + 1}, because it is a duplicate: {suggestion}"
+ )
continue
- if 'const' in suggestion['suggestion_content'] and 'instead' in suggestion[
- 'suggestion_content'] and 'let' in suggestion['suggestion_content']:
+ if (
+ 'const' in suggestion['suggestion_content']
+ and 'instead' in suggestion['suggestion_content']
+ and 'let' in suggestion['suggestion_content']
+ ):
get_logger().debug(
- f"Skipping suggestion {i + 1}, because it uses 'const instead let': {suggestion}")
+ f"Skipping suggestion {i + 1}, because it uses 'const instead let': {suggestion}"
+ )
continue
if ('existing_code' in suggestion) and ('improved_code' in suggestion):
@@ -533,9 +733,12 @@ class PRCodeSuggestions:
suggestion_list.append(suggestion)
else:
get_logger().info(
- f"Skipping suggestion {i + 1}, because it does not contain 'existing_code' or 'improved_code': {suggestion}")
+ f"Skipping suggestion {i + 1}, because it does not contain 'existing_code' or 'improved_code': {suggestion}"
+ )
except Exception as e:
- get_logger().error(f"Error processing suggestion {i + 1}: {suggestion}, error: {e}")
+ get_logger().error(
+ f"Error processing suggestion {i + 1}: {suggestion}, error: {e}"
+ )
data['code_suggestions'] = suggestion_list
return data
@@ -546,46 +749,72 @@ class PRCodeSuggestions:
if not data['code_suggestions']:
get_logger().info('No suggestions found to improve this PR.')
if self.progress_response:
- return self.git_provider.edit_comment(self.progress_response,
- body='No suggestions found to improve this PR.')
+ return self.git_provider.edit_comment(
+ self.progress_response,
+ body='No suggestions found to improve this PR.',
+ )
else:
- return self.git_provider.publish_comment('No suggestions found to improve this PR.')
+ return self.git_provider.publish_comment(
+ 'No suggestions found to improve this PR.'
+ )
for d in data['code_suggestions']:
try:
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"suggestion: {d}")
relevant_file = d['relevant_file'].strip()
- relevant_lines_start = int(d['relevant_lines_start']) # absolute position
+ relevant_lines_start = int(
+ d['relevant_lines_start']
+ ) # absolute position
relevant_lines_end = int(d['relevant_lines_end'])
content = d['suggestion_content'].rstrip()
new_code_snippet = d['improved_code'].rstrip()
label = d['label'].strip()
if new_code_snippet:
- new_code_snippet = self.dedent_code(relevant_file, relevant_lines_start, new_code_snippet)
+ new_code_snippet = self.dedent_code(
+ relevant_file, relevant_lines_start, new_code_snippet
+ )
if d.get('score'):
- body = f"**Suggestion:** {content} [{label}, importance: {d.get('score')}]\n```suggestion\n" + new_code_snippet + "\n```"
+ body = (
+ f"**Suggestion:** {content} [{label}, importance: {d.get('score')}]\n```suggestion\n"
+ + new_code_snippet
+ + "\n```"
+ )
else:
- body = f"**Suggestion:** {content} [{label}]\n```suggestion\n" + new_code_snippet + "\n```"
- code_suggestions.append({'body': body, 'relevant_file': relevant_file,
- 'relevant_lines_start': relevant_lines_start,
- 'relevant_lines_end': relevant_lines_end,
- 'original_suggestion': d})
+ body = (
+ f"**Suggestion:** {content} [{label}]\n```suggestion\n"
+ + new_code_snippet
+ + "\n```"
+ )
+ code_suggestions.append(
+ {
+ 'body': body,
+ 'relevant_file': relevant_file,
+ 'relevant_lines_start': relevant_lines_start,
+ 'relevant_lines_end': relevant_lines_end,
+ 'original_suggestion': d,
+ }
+ )
except Exception:
get_logger().info(f"Could not parse suggestion: {d}")
is_successful = self.git_provider.publish_code_suggestions(code_suggestions)
if not is_successful:
- get_logger().info("Failed to publish code suggestions, trying to publish each suggestion separately")
+ get_logger().info(
+ "Failed to publish code suggestions, trying to publish each suggestion separately"
+ )
for code_suggestion in code_suggestions:
self.git_provider.publish_code_suggestions([code_suggestion])
def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet):
try: # dedent code snippet
- self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files \
+ self.diff_files = (
+ self.git_provider.diff_files
+ if self.git_provider.diff_files
else self.git_provider.get_diff_files()
+ )
original_initial_line = None
for file in self.diff_files:
if file.filename.strip() == relevant_file:
@@ -594,29 +823,44 @@ class PRCodeSuggestions:
if relevant_lines_start > len(file_lines):
get_logger().warning(
"Could not dedent code snippet, because relevant_lines_start is out of range",
- artifact={'filename': file.filename,
- 'file_content': file.head_file,
- 'relevant_lines_start': relevant_lines_start,
- 'new_code_snippet': new_code_snippet})
+ artifact={
+ 'filename': file.filename,
+ 'file_content': file.head_file,
+ 'relevant_lines_start': relevant_lines_start,
+ 'new_code_snippet': new_code_snippet,
+ },
+ )
return new_code_snippet
else:
original_initial_line = file_lines[relevant_lines_start - 1]
else:
- get_logger().warning("Could not dedent code snippet, because head_file is missing",
- artifact={'filename': file.filename,
- 'relevant_lines_start': relevant_lines_start,
- 'new_code_snippet': new_code_snippet})
+ get_logger().warning(
+ "Could not dedent code snippet, because head_file is missing",
+ artifact={
+ 'filename': file.filename,
+ 'relevant_lines_start': relevant_lines_start,
+ 'new_code_snippet': new_code_snippet,
+ },
+ )
return new_code_snippet
break
if original_initial_line:
suggested_initial_line = new_code_snippet.splitlines()[0]
- original_initial_spaces = len(original_initial_line) - len(original_initial_line.lstrip())
- suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip())
+ original_initial_spaces = len(original_initial_line) - len(
+ original_initial_line.lstrip()
+ )
+ suggested_initial_spaces = len(suggested_initial_line) - len(
+ suggested_initial_line.lstrip()
+ )
delta_spaces = original_initial_spaces - suggested_initial_spaces
if delta_spaces > 0:
- new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
+ new_code_snippet = textwrap.indent(
+ new_code_snippet, delta_spaces * " "
+ ).rstrip('\n')
except Exception as e:
- get_logger().error(f"Error when dedenting code snippet for file {relevant_file}, error: {e}")
+ get_logger().error(
+ f"Error when dedenting code snippet for file {relevant_file}, error: {e}"
+ )
return new_code_snippet
@@ -644,42 +888,72 @@ class PRCodeSuggestions:
# find the first letter in the line that starts with a valid letter
for j, char in enumerate(line):
if not char.isdigit():
- patches_diff_lines[i] = line[j + 1:]
+ patches_diff_lines[i] = line[j + 1 :]
break
- self.patches_diff_list_no_line_numbers.append('\n'.join(patches_diff_lines))
+ self.patches_diff_list_no_line_numbers.append(
+ '\n'.join(patches_diff_lines)
+ )
return self.patches_diff_list_no_line_numbers
except Exception as e:
- get_logger().error(f"Error removing line numbers from patches_diff_list, error: {e}")
+ get_logger().error(
+ f"Error removing line numbers from patches_diff_list, error: {e}"
+ )
return patches_diff_list
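A tiny sketch of the stripping loop in `remove_line_numbers` above: drop the leading digits plus the single separator character that follows them. The sample hunk line is hypothetical.

```python
line = "12 +    return a + b"  # hypothetical '__new hunk__' line with a leading line number
stripped = line
for j, char in enumerate(line):
    if not char.isdigit():
        stripped = line[j + 1:]
        break
print(stripped)  # "+    return a + b"
```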
async def _prepare_prediction_extended(self, model: str) -> dict:
- self.patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
- max_calls=get_settings().pr_code_suggestions.max_number_of_calls)
+ self.patches_diff_list = get_pr_multi_diffs(
+ self.git_provider,
+ self.token_handler,
+ model,
+ max_calls=get_settings().pr_code_suggestions.max_number_of_calls,
+ )
# create a copy of the patches_diff_list, without line numbers for '__new hunk__' sections
- self.patches_diff_list_no_line_numbers = self.remove_line_numbers(self.patches_diff_list)
+ self.patches_diff_list_no_line_numbers = self.remove_line_numbers(
+ self.patches_diff_list
+ )
if self.patches_diff_list:
- get_logger().info(f"Number of PR chunk calls: {len(self.patches_diff_list)}")
+ get_logger().info(
+ f"Number of PR chunk calls: {len(self.patches_diff_list)}"
+ )
get_logger().debug(f"PR diff:", artifact=self.patches_diff_list)
# parallelize calls to AI:
if get_settings().pr_code_suggestions.parallel_calls:
prediction_list = await asyncio.gather(
- *[self._get_prediction(model, patches_diff, patches_diff_no_line_numbers) for
- patches_diff, patches_diff_no_line_numbers in
- zip(self.patches_diff_list, self.patches_diff_list_no_line_numbers)])
+ *[
+ self._get_prediction(
+ model, patches_diff, patches_diff_no_line_numbers
+ )
+ for patches_diff, patches_diff_no_line_numbers in zip(
+ self.patches_diff_list,
+ self.patches_diff_list_no_line_numbers,
+ )
+ ]
+ )
self.prediction_list = prediction_list
else:
prediction_list = []
- for patches_diff, patches_diff_no_line_numbers in zip(self.patches_diff_list, self.patches_diff_list_no_line_numbers):
- prediction = await self._get_prediction(model, patches_diff, patches_diff_no_line_numbers)
+ for patches_diff, patches_diff_no_line_numbers in zip(
+ self.patches_diff_list, self.patches_diff_list_no_line_numbers
+ ):
+ prediction = await self._get_prediction(
+ model, patches_diff, patches_diff_no_line_numbers
+ )
prediction_list.append(prediction)
data = {"code_suggestions": []}
- for j, predictions in enumerate(prediction_list): # each call adds an element to the list
+ for j, predictions in enumerate(
+ prediction_list
+ ): # each call adds an element to the list
if "code_suggestions" in predictions:
- score_threshold = max(1, int(get_settings().pr_code_suggestions.suggestions_score_threshold))
+ score_threshold = max(
+ 1,
+ int(
+ get_settings().pr_code_suggestions.suggestions_score_threshold
+ ),
+ )
for i, prediction in enumerate(predictions["code_suggestions"]):
try:
score = int(prediction.get("score", 1))
@@ -688,10 +962,13 @@ class PRCodeSuggestions:
else:
get_logger().info(
f"Removing suggestions {i} from call {j}, because score is {score}, and score_threshold is {score_threshold}",
- artifact=prediction)
+ artifact=prediction,
+ )
except Exception as e:
- get_logger().error(f"Error getting PR diff for suggestion {i} in call {j}, error: {e}",
- artifact={"prediction": prediction})
+ get_logger().error(
+ f"Error getting PR diff for suggestion {i} in call {j}, error: {e}",
+ artifact={"prediction": prediction},
+ )
self.data = data
else:
get_logger().warning(f"Empty PR diff list")
@@ -706,7 +983,10 @@ class PRCodeSuggestions:
pr_body += "No suggestions found to improve this PR."
return pr_body
- if get_settings().pr_code_suggestions.enable_intro_text and get_settings().config.is_auto_command:
+ if (
+ get_settings().pr_code_suggestions.enable_intro_text
+ and get_settings().config.is_auto_command
+ ):
pr_body += "Explore these optional code suggestions:\n\n"
language_extension_map_org = get_settings().language_extension_map_org
@@ -731,17 +1011,25 @@ class PRCodeSuggestions:
# sort suggestions_labels by the suggestion with the highest score
suggestions_labels = dict(
- sorted(suggestions_labels.items(), key=lambda x: max([s['score'] for s in x[1]]), reverse=True))
+ sorted(
+ suggestions_labels.items(),
+ key=lambda x: max([s['score'] for s in x[1]]),
+ reverse=True,
+ )
+ )
# sort the suggestions inside each label group by score
for label, suggestions in suggestions_labels.items():
- suggestions_labels[label] = sorted(suggestions, key=lambda x: x['score'], reverse=True)
+ suggestions_labels[label] = sorted(
+ suggestions, key=lambda x: x['score'], reverse=True
+ )
counter_suggestions = 0
for label, suggestions in suggestions_labels.items():
num_suggestions = len(suggestions)
- pr_body += f"""| {label.capitalize()} | \n"""
+ pr_body += (
+ f""" | {label.capitalize()} | \n"""
+ )
for i, suggestion in enumerate(suggestions):
-
relevant_file = suggestion['relevant_file'].strip()
relevant_lines_start = int(suggestion['relevant_lines_start'])
relevant_lines_end = int(suggestion['relevant_lines_end'])
@@ -752,21 +1040,25 @@ class PRCodeSuggestions:
range_str = f"[{relevant_lines_start}-{relevant_lines_end}]"
try:
- code_snippet_link = self.git_provider.get_line_link(relevant_file, relevant_lines_start,
- relevant_lines_end)
+ code_snippet_link = self.git_provider.get_line_link(
+ relevant_file, relevant_lines_start, relevant_lines_end
+ )
except:
code_snippet_link = ""
# add html table for each suggestion
suggestion_content = suggestion['suggestion_content'].rstrip()
CHAR_LIMIT_PER_LINE = 84
- suggestion_content = insert_br_after_x_chars(suggestion_content, CHAR_LIMIT_PER_LINE)
+ suggestion_content = insert_br_after_x_chars(
+ suggestion_content, CHAR_LIMIT_PER_LINE
+ )
# pr_body += f" {suggestion_content}"
existing_code = suggestion['existing_code'].rstrip() + "\n"
improved_code = suggestion['improved_code'].rstrip() + "\n"
- diff = difflib.unified_diff(existing_code.split('\n'),
- improved_code.split('\n'), n=999)
+ diff = difflib.unified_diff(
+ existing_code.split('\n'), improved_code.split('\n'), n=999
+ )
patch_orig = "\n".join(diff)
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
@@ -776,10 +1068,14 @@ class PRCodeSuggestions:
pr_body += f"""\n\n"""
else:
pr_body += f""" | | \n\n"""
- suggestion_summary = suggestion['one_sentence_summary'].strip().rstrip('.')
+ suggestion_summary = (
+ suggestion['one_sentence_summary'].strip().rstrip('.')
+ )
if "'<" in suggestion_summary and ">'" in suggestion_summary:
# escape the '<' and '>' characters, otherwise they are interpreted as html tags
- get_logger().info(f"Escaped suggestion summary: {suggestion_summary}")
+ get_logger().info(
+ f"Escaped suggestion summary: {suggestion_summary}"
+ )
suggestion_summary = suggestion_summary.replace("'<", "`<")
suggestion_summary = suggestion_summary.replace(">'", ">`")
if '`' in suggestion_summary:
@@ -815,12 +1111,18 @@ class PRCodeSuggestions:
pr_body += """ | """
return pr_body
except Exception as e:
- get_logger().info(f"Failed to publish summarized code suggestions, error: {e}")
+ get_logger().info(
+ f"Failed to publish summarized code suggestions, error: {e}"
+ )
return ""
def get_score_str(self, score: int) -> str:
- th_high = get_settings().pr_code_suggestions.get('new_score_mechanism_th_high', 9)
- th_medium = get_settings().pr_code_suggestions.get('new_score_mechanism_th_medium', 7)
+ th_high = get_settings().pr_code_suggestions.get(
+ 'new_score_mechanism_th_high', 9
+ )
+ th_medium = get_settings().pr_code_suggestions.get(
+ 'new_score_mechanism_th_medium', 7
+ )
if score >= th_high:
return "高"
elif score >= th_medium:
@@ -828,12 +1130,14 @@ class PRCodeSuggestions:
else: # score < 7
return "低"
- async def self_reflect_on_suggestions(self,
- suggestion_list: List,
- patches_diff: str,
- model: str,
- prev_suggestions_str: str = "",
- dedicated_prompt: str = "") -> str:
+ async def self_reflect_on_suggestions(
+ self,
+ suggestion_list: List,
+ patches_diff: str,
+ model: str,
+ prev_suggestions_str: str = "",
+ dedicated_prompt: str = "",
+ ) -> str:
if not suggestion_list:
return ""
@@ -842,31 +1146,44 @@ class PRCodeSuggestions:
for i, suggestion in enumerate(suggestion_list):
suggestion_str += f"suggestion {i + 1}: " + str(suggestion) + '\n\n'
- variables = {'suggestion_list': suggestion_list,
- 'suggestion_str': suggestion_str,
- "diff": patches_diff,
- 'num_code_suggestions': len(suggestion_list),
- 'prev_suggestions_str': prev_suggestions_str,
- "is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
- 'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False)}
+ variables = {
+ 'suggestion_list': suggestion_list,
+ 'suggestion_str': suggestion_str,
+ "diff": patches_diff,
+ 'num_code_suggestions': len(suggestion_list),
+ 'prev_suggestions_str': prev_suggestions_str,
+ "is_ai_metadata": get_settings().get(
+ "config.enable_ai_metadata", False
+ ),
+ 'duplicate_prompt_examples': get_settings().config.get(
+ 'duplicate_prompt_examples', False
+ ),
+ }
environment = Environment(undefined=StrictUndefined)
if dedicated_prompt:
system_prompt_reflect = environment.from_string(
- get_settings().get(dedicated_prompt).system).render(variables)
+ get_settings().get(dedicated_prompt).system
+ ).render(variables)
user_prompt_reflect = environment.from_string(
- get_settings().get(dedicated_prompt).user).render(variables)
+ get_settings().get(dedicated_prompt).user
+ ).render(variables)
else:
system_prompt_reflect = environment.from_string(
- get_settings().pr_code_suggestions_reflect_prompt.system).render(variables)
+ get_settings().pr_code_suggestions_reflect_prompt.system
+ ).render(variables)
user_prompt_reflect = environment.from_string(
- get_settings().pr_code_suggestions_reflect_prompt.user).render(variables)
+ get_settings().pr_code_suggestions_reflect_prompt.user
+ ).render(variables)
with get_logger().contextualize(command="self_reflect_on_suggestions"):
- response_reflect, finish_reason_reflect = await self.ai_handler.chat_completion(model=model,
- system=system_prompt_reflect,
- user=user_prompt_reflect)
+ (
+ response_reflect,
+ finish_reason_reflect,
+ ) = await self.ai_handler.chat_completion(
+ model=model, system=system_prompt_reflect, user=user_prompt_reflect
+ )
except Exception as e:
get_logger().info(f"Could not reflect on suggestions, error: {e}")
return ""
- return response_reflect
\ No newline at end of file
+ return response_reflect
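The summarized-suggestions table built in `generate_summarized_suggestions` above turns each pair of existing/improved snippets into a patch with `difflib.unified_diff`. A minimal sketch with hypothetical snippets, mirroring the `n=999` context window and header-stripping shown in the diff:

```python
import difflib

existing_code = "def add(a, b):\n    return a - b\n"
improved_code = "def add(a, b):\n    return a + b\n"

# n=999 keeps the whole snippet as context; after joining, the first five lines are the
# ---/+++/@@ headers plus the blank lines the join introduces.
diff = difflib.unified_diff(existing_code.split("\n"), improved_code.split("\n"), n=999)
patch_orig = "\n".join(diff)
patch = "\n".join(patch_orig.splitlines()[5:]).strip("\n")
print(patch)
```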
diff --git a/apps/utils/pr_agent/tools/pr_config.py b/apps/utils/pr_agent/tools/pr_config.py
index a00e015..cfee29e 100644
--- a/apps/utils/pr_agent/tools/pr_config.py
+++ b/apps/utils/pr_agent/tools/pr_config.py
@@ -9,6 +9,7 @@ class PRConfig:
"""
The PRConfig class is responsible for listing all configuration options available for the user.
"""
+
def __init__(self, pr_url: str, args=None, ai_handler=None):
"""
Initialize the PRConfig object with the necessary attributes and objects to comment on a pull request.
@@ -34,20 +35,43 @@ class PRConfig:
conf_settings = Dynaconf(settings_files=[conf_file])
configuration_headers = [header.lower() for header in conf_settings.keys()]
relevant_configs = {
- header: configs for header, configs in get_settings().to_dict().items()
- if (header.lower().startswith("pr_") or header.lower().startswith("config")) and header.lower() in configuration_headers
+ header: configs
+ for header, configs in get_settings().to_dict().items()
+ if (header.lower().startswith("pr_") or header.lower().startswith("config"))
+ and header.lower() in configuration_headers
}
- skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys", "app_id", "redirect",
- 'trial_prefix_message', 'no_eligible_message', 'identity_provider', 'ALLOWED_REPOS',
- 'APP_NAME', 'PERSONAL_ACCESS_TOKEN', 'shared_secret', 'key', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'user_token',
- 'private_key', 'private_key_id', 'client_id', 'client_secret', 'token', 'bearer_token']
+ skip_keys = [
+ 'ai_disclaimer',
+ 'ai_disclaimer_title',
+ 'ANALYTICS_FOLDER',
+ 'secret_provider',
+ "skip_keys",
+ "app_id",
+ "redirect",
+ 'trial_prefix_message',
+ 'no_eligible_message',
+ 'identity_provider',
+ 'ALLOWED_REPOS',
+ 'APP_NAME',
+ 'PERSONAL_ACCESS_TOKEN',
+ 'shared_secret',
+ 'key',
+ 'AWS_ACCESS_KEY_ID',
+ 'AWS_SECRET_ACCESS_KEY',
+ 'user_token',
+ 'private_key',
+ 'private_key_id',
+ 'client_id',
+ 'client_secret',
+ 'token',
+ 'bearer_token',
+ ]
extra_skip_keys = get_settings().config.get('config.skip_keys', [])
if extra_skip_keys:
skip_keys.extend(extra_skip_keys)
skip_keys_lower = [key.lower() for key in skip_keys]
-
markdown_text = " 🛠️ PR-Agent Configurations: \n\n"
markdown_text += f"\n\n```yaml\n\n"
for header, configs in relevant_configs.items():
@@ -61,5 +85,7 @@ class PRConfig:
markdown_text += " "
markdown_text += "\n```"
markdown_text += "\n \n"
- get_logger().info(f"Possible Configurations outputted to PR comment", artifact=markdown_text)
+ get_logger().info(
+ f"Possible Configurations outputted to PR comment", artifact=markdown_text
+ )
return markdown_text
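For context, the filtering above amounts to: keep only settings sections that look like tool or config sections and also exist in the default configuration file, and drop secret-like keys before echoing anything into a PR comment. A self-contained sketch with plain dicts standing in for the Dynaconf objects (all names and values here are illustrative):

```python
# Illustrative stand-ins for get_settings().to_dict() and the default settings headers.
all_settings = {
    "pr_description": {"publish_labels": True, "user_token": "secret"},
    "pr_reviewer": {"num_max_findings": 3},
    "github": {"deployment_type": "user"},
}
configuration_headers = {"pr_description", "config"}  # headers present in the default settings file
skip_keys_lower = {"user_token", "key", "token"}      # secret-like keys to hide

relevant_configs = {
    header: configs
    for header, configs in all_settings.items()
    if (header.lower().startswith("pr_") or header.lower().startswith("config"))
    and header.lower() in configuration_headers
}

lines = []
for header, configs in relevant_configs.items():
    lines.append(f"{header.lower()}:")
    for key, value in configs.items():
        if key.lower() in skip_keys_lower:
            continue  # never echo credentials back into a PR comment
        lines.append(f"  {key.lower()}: {value}")
print("\n".join(lines))  # -> "pr_description:" followed by "  publish_labels: True"
```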
diff --git a/apps/utils/pr_agent/tools/pr_description.py b/apps/utils/pr_agent/tools/pr_description.py
index 89a589b..f929e81 100644
--- a/apps/utils/pr_agent/tools/pr_description.py
+++ b/apps/utils/pr_agent/tools/pr_description.py
@@ -10,27 +10,38 @@ from jinja2 import Environment, StrictUndefined
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
-from utils.pr_agent.algo.pr_processing import (OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD,
- get_pr_diff,
- get_pr_diff_multiple_patchs,
- retry_with_fallback_models)
+from utils.pr_agent.algo.pr_processing import (
+ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD,
+ get_pr_diff,
+ get_pr_diff_multiple_patchs,
+ retry_with_fallback_models,
+)
from utils.pr_agent.algo.token_handler import TokenHandler
-from utils.pr_agent.algo.utils import (ModelType, PRDescriptionHeader, clip_tokens,
- get_max_tokens, get_user_labels, load_yaml,
- set_custom_labels,
- show_relevant_configurations)
+from utils.pr_agent.algo.utils import (
+ ModelType,
+ PRDescriptionHeader,
+ clip_tokens,
+ get_max_tokens,
+ get_user_labels,
+ load_yaml,
+ set_custom_labels,
+ show_relevant_configurations,
+)
from utils.pr_agent.config_loader import get_settings
-from utils.pr_agent.git_providers import (GithubProvider, get_git_provider_with_context)
+from utils.pr_agent.git_providers import GithubProvider, get_git_provider_with_context
from utils.pr_agent.git_providers.git_provider import get_main_pr_language
from utils.pr_agent.log import get_logger
from utils.pr_agent.servers.help import HelpMessage
-from utils.pr_agent.tools.ticket_pr_compliance_check import (
- extract_and_cache_pr_tickets)
+from utils.pr_agent.tools.ticket_pr_compliance_check import extract_and_cache_pr_tickets
class PRDescription:
- def __init__(self, pr_url: str, args: list = None,
- ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
+ def __init__(
+ self,
+ pr_url: str,
+ args: list = None,
+ ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
+ ):
"""
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model.
@@ -44,11 +55,22 @@ class PRDescription:
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.pr_id = self.git_provider.get_pr_id()
- self.keys_fix = ["filename:", "language:", "changes_summary:", "changes_title:", "description:", "title:"]
+ self.keys_fix = [
+ "filename:",
+ "language:",
+ "changes_summary:",
+ "changes_title:",
+ "description:",
+ "title:",
+ ]
- if get_settings().pr_description.enable_semantic_files_types and not self.git_provider.is_supported(
- "gfm_markdown"):
- get_logger().debug(f"Disabling semantic files types for {self.pr_id}, gfm_markdown not supported.")
+ if (
+ get_settings().pr_description.enable_semantic_files_types
+ and not self.git_provider.is_supported("gfm_markdown")
+ ):
+ get_logger().debug(
+ f"Disabling semantic files types for {self.pr_id}, gfm_markdown not supported."
+ )
get_settings().pr_description.enable_semantic_files_types = False
# Initialize the AI handler
@@ -56,7 +78,9 @@ class PRDescription:
self.ai_handler.main_pr_language = self.main_pr_language
# Initialize the variables dictionary
- self.COLLAPSIBLE_FILE_LIST_THRESHOLD = get_settings().pr_description.get("collapsible_file_list_threshold", 8)
+ self.COLLAPSIBLE_FILE_LIST_THRESHOLD = get_settings().pr_description.get(
+ "collapsible_file_list_threshold", 8
+ )
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
@@ -69,8 +93,11 @@ class PRDescription:
"custom_labels_class": "", # will be filled if necessary in 'set_custom_labels' function
"enable_semantic_files_types": get_settings().pr_description.enable_semantic_files_types,
"related_tickets": "",
- "include_file_summary_changes": len(self.git_provider.get_diff_files()) <= self.COLLAPSIBLE_FILE_LIST_THRESHOLD,
- 'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False),
+ "include_file_summary_changes": len(self.git_provider.get_diff_files())
+ <= self.COLLAPSIBLE_FILE_LIST_THRESHOLD,
+ 'duplicate_prompt_examples': get_settings().config.get(
+ 'duplicate_prompt_examples', False
+ ),
}
self.user_description = self.git_provider.get_user_description()
@@ -91,10 +118,14 @@ class PRDescription:
async def run(self):
try:
get_logger().info(f"Generating a PR description for pr_id: {self.pr_id}")
- relevant_configs = {'pr_description': dict(get_settings().pr_description),
- 'config': dict(get_settings().config)}
+ relevant_configs = {
+ 'pr_description': dict(get_settings().pr_description),
+ 'config': dict(get_settings().config),
+ }
get_logger().debug("Relevant configs", artifacts=relevant_configs)
- if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False):
+ if get_settings().config.publish_output and not get_settings().config.get(
+ 'is_auto_command', False
+ ):
self.git_provider.publish_comment("准备 PR 描述中...", is_temporary=True)
# ticket extraction if exists
@@ -119,40 +150,73 @@ class PRDescription:
get_logger().debug(f"Publishing labels disabled")
if get_settings().pr_description.use_description_markers:
- pr_title, pr_body, changes_walkthrough, pr_file_changes = self._prepare_pr_answer_with_markers()
+ (
+ pr_title,
+ pr_body,
+ changes_walkthrough,
+ pr_file_changes,
+ ) = self._prepare_pr_answer_with_markers()
else:
- pr_title, pr_body, changes_walkthrough, pr_file_changes = self._prepare_pr_answer()
- if not self.git_provider.is_supported(
- "publish_file_comments") or not get_settings().pr_description.inline_file_summary:
+ (
+ pr_title,
+ pr_body,
+ changes_walkthrough,
+ pr_file_changes,
+ ) = self._prepare_pr_answer()
+ if (
+ not self.git_provider.is_supported("publish_file_comments")
+ or not get_settings().pr_description.inline_file_summary
+ ):
pr_body += "\n\n" + changes_walkthrough
- get_logger().debug("PR output", artifact={"title": pr_title, "body": pr_body})
+ get_logger().debug(
+ "PR output", artifact={"title": pr_title, "body": pr_body}
+ )
# Add help text if gfm_markdown is supported
- if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_description.enable_help_text:
+ if (
+ self.git_provider.is_supported("gfm_markdown")
+ and get_settings().pr_description.enable_help_text
+ ):
pr_body += " \n\n ✨ 工具使用指南: \n\n"
pr_body += HelpMessage.get_describe_usage_guide()
pr_body += "\n \n"
- elif get_settings().pr_description.enable_help_comment and self.git_provider.is_supported("gfm_markdown"):
+ elif (
+ get_settings().pr_description.enable_help_comment
+ and self.git_provider.is_supported("gfm_markdown")
+ ):
if isinstance(self.git_provider, GithubProvider):
-                    pr_body += ('\n\n___\n\n> Need help? Type /help how to ... '
-                                'for any question about PR-Agent usage, comment here. Check out the '
-                                'documentation '
-                                'for more info. ')
-                else:  # gitlab
-                    pr_body += ("\n\n___\n\nNeed help? - Type /help how to ... in a comment "
-                                "for any question about PR-Agent usage, post here. - Check out the "
-                                "documentation for more info. ")
+                    pr_body += (
+                        '\n\n___\n\n> Need help? Type /help how to ... '
+                        'for any question about PR-Agent usage, comment here. Check out the '
+                        'documentation '
+                        'for more info. '
+                    )
+                else:  # gitlab
+                    pr_body += (
+                        "\n\n___\n\nNeed help? - Type /help how to ... in a comment "
+                        "for any question about PR-Agent usage, post here. - Check out the "
+                        "documentation for more info. "
+                    )
# elif get_settings().pr_description.enable_help_comment:
# pr_body += '\n\n___\n\n> 💡 **PR-Agent usage**: Comment `/help "your question"` on any pull request to receive relevant information'
# Output the relevant configurations if enabled
- if get_settings().get('config', {}).get('output_relevant_configurations', False):
- pr_body += show_relevant_configurations(relevant_section='pr_description')
+ if (
+ get_settings()
+ .get('config', {})
+ .get('output_relevant_configurations', False)
+ ):
+ pr_body += show_relevant_configurations(
+ relevant_section='pr_description'
+ )
if get_settings().config.publish_output:
-
# publish labels
- if get_settings().pr_description.publish_labels and pr_labels and self.git_provider.is_supported("get_labels"):
+ if (
+ get_settings().pr_description.publish_labels
+ and pr_labels
+ and self.git_provider.is_supported("get_labels")
+ ):
original_labels = self.git_provider.get_pr_labels(update=True)
get_logger().debug(f"original labels", artifact=original_labels)
user_labels = get_user_labels(original_labels)
@@ -165,20 +229,29 @@ class PRDescription:
# publish description
if get_settings().pr_description.publish_description_as_comment:
- full_markdown_description = f"## Title\n\n{pr_title}\n\n___\n{pr_body}"
- if get_settings().pr_description.publish_description_as_comment_persistent:
- self.git_provider.publish_persistent_comment(full_markdown_description,
- initial_header="## Title",
- update_header=True,
- name="describe",
- final_update_message=False, )
+ full_markdown_description = (
+ f"## Title\n\n{pr_title}\n\n___\n{pr_body}"
+ )
+ if (
+ get_settings().pr_description.publish_description_as_comment_persistent
+ ):
+ self.git_provider.publish_persistent_comment(
+ full_markdown_description,
+ initial_header="## Title",
+ update_header=True,
+ name="describe",
+ final_update_message=False,
+ )
else:
self.git_provider.publish_comment(full_markdown_description)
else:
self.git_provider.publish_description(pr_title, pr_body)
# publish final update message
- if (get_settings().pr_description.final_update_message and not get_settings().config.get('is_auto_command', False)):
+ if (
+ get_settings().pr_description.final_update_message
+ and not get_settings().config.get('is_auto_command', False)
+ ):
latest_commit_url = self.git_provider.get_latest_commit_url()
if latest_commit_url:
pr_url = self.git_provider.get_pr_url()
@@ -186,22 +259,40 @@ class PRDescription:
self.git_provider.publish_comment(update_comment)
self.git_provider.remove_initial_comment()
else:
- get_logger().info('PR description, but not published since publish_output is False.')
+ get_logger().info(
+ 'PR description, but not published since publish_output is False.'
+ )
get_settings().data = {"artifact": pr_body}
return
except Exception as e:
- get_logger().error(f"Error generating PR description {self.pr_id}: {e}",
- artifact={"traceback": traceback.format_exc()})
+ get_logger().error(
+ f"Error generating PR description {self.pr_id}: {e}",
+ artifact={"traceback": traceback.format_exc()},
+ )
return ""
async def _prepare_prediction(self, model: str) -> None:
- if get_settings().pr_description.use_description_markers and 'pr_agent:' not in self.user_description:
- get_logger().info("Markers were enabled, but user description does not contain markers. skipping AI prediction")
+ if (
+ get_settings().pr_description.use_description_markers
+ and 'pr_agent:' not in self.user_description
+ ):
+ get_logger().info(
+ "Markers were enabled, but user description does not contain markers. skipping AI prediction"
+ )
return None
- large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings()
- output = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling, return_remaining_files=True)
+ large_pr_handling = (
+ get_settings().pr_description.enable_large_pr_handling
+ and "pr_description_only_files_prompts" in get_settings()
+ )
+ output = get_pr_diff(
+ self.git_provider,
+ self.token_handler,
+ model,
+ large_pr_handling=large_pr_handling,
+ return_remaining_files=True,
+ )
if isinstance(output, tuple):
patches_diff, remaining_files_list = output
else:
@@ -213,14 +304,18 @@ class PRDescription:
if patches_diff:
# generate the prediction
get_logger().debug(f"PR diff", artifact=self.patches_diff)
- self.prediction = await self._get_prediction(model, patches_diff, prompt="pr_description_prompt")
+ self.prediction = await self._get_prediction(
+ model, patches_diff, prompt="pr_description_prompt"
+ )
# extend the prediction with additional files not shown
if get_settings().pr_description.enable_semantic_files_types:
self.prediction = await self.extend_uncovered_files(self.prediction)
else:
- get_logger().error(f"Error getting PR diff {self.pr_id}",
- artifact={"traceback": traceback.format_exc()})
+ get_logger().error(
+ f"Error getting PR diff {self.pr_id}",
+ artifact={"traceback": traceback.format_exc()},
+ )
self.prediction = None
else:
# get the diff in multiple patches, with the token handler only for the files prompt
@@ -231,9 +326,16 @@ class PRDescription:
get_settings().pr_description_only_files_prompts.system,
get_settings().pr_description_only_files_prompts.user,
)
- (patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict,
- files_in_patches_list) = get_pr_diff_multiple_patchs(
- self.git_provider, token_handler_only_files_prompt, model)
+ (
+ patches_compressed_list,
+ total_tokens_list,
+ deleted_files_list,
+ remaining_files_list,
+ file_dict,
+ files_in_patches_list,
+ ) = get_pr_diff_multiple_patchs(
+ self.git_provider, token_handler_only_files_prompt, model
+ )
# get the files prediction for each patch
if not get_settings().pr_description.async_ai_calls:
@@ -241,8 +343,9 @@ class PRDescription:
for i, patches in enumerate(patches_compressed_list): # sync calls
patches_diff = "\n".join(patches)
get_logger().debug(f"PR diff number {i + 1} for describe files")
- prediction_files = await self._get_prediction(model, patches_diff,
- prompt="pr_description_only_files_prompts")
+ prediction_files = await self._get_prediction(
+ model, patches_diff, prompt="pr_description_only_files_prompts"
+ )
results.append(prediction_files)
else: # async calls
tasks = []
@@ -251,34 +354,52 @@ class PRDescription:
patches_diff = "\n".join(patches)
get_logger().debug(f"PR diff number {i + 1} for describe files")
task = asyncio.create_task(
- self._get_prediction(model, patches_diff, prompt="pr_description_only_files_prompts"))
+ self._get_prediction(
+ model,
+ patches_diff,
+ prompt="pr_description_only_files_prompts",
+ )
+ )
tasks.append(task)
# Wait for all tasks to complete
results = await asyncio.gather(*tasks)
file_description_str_list = []
for i, result in enumerate(results):
- prediction_files = result.strip().removeprefix('```yaml').strip('`').strip()
- if load_yaml(prediction_files, keys_fix_yaml=self.keys_fix) and prediction_files.startswith('pr_files'):
- prediction_files = prediction_files.removeprefix('pr_files:').strip()
+ prediction_files = (
+ result.strip().removeprefix('```yaml').strip('`').strip()
+ )
+ if load_yaml(
+ prediction_files, keys_fix_yaml=self.keys_fix
+ ) and prediction_files.startswith('pr_files'):
+ prediction_files = prediction_files.removeprefix(
+ 'pr_files:'
+ ).strip()
file_description_str_list.append(prediction_files)
else:
- get_logger().debug(f"failed to generate predictions in iteration {i + 1} for describe files")
+ get_logger().debug(
+ f"failed to generate predictions in iteration {i + 1} for describe files"
+ )
# generate files_walkthrough string, with proper token handling
token_handler_only_description_prompt = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_description_only_description_prompts.system,
- get_settings().pr_description_only_description_prompts.user)
+ get_settings().pr_description_only_description_prompts.user,
+ )
files_walkthrough = "\n".join(file_description_str_list)
files_walkthrough_prompt = copy.deepcopy(files_walkthrough)
MAX_EXTRA_FILES_TO_PROMPT = 50
if remaining_files_list:
- files_walkthrough_prompt += "\n\nNo more token budget. Additional unprocessed files:"
+ files_walkthrough_prompt += (
+ "\n\nNo more token budget. Additional unprocessed files:"
+ )
for i, file in enumerate(remaining_files_list):
files_walkthrough_prompt += f"\n- {file}"
if i >= MAX_EXTRA_FILES_TO_PROMPT:
- get_logger().debug(f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}")
+ get_logger().debug(
+ f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}"
+ )
files_walkthrough_prompt += f"\n... and {len(remaining_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more"
break
if deleted_files_list:
@@ -286,32 +407,57 @@ class PRDescription:
for i, file in enumerate(deleted_files_list):
files_walkthrough_prompt += f"\n- {file}"
if i >= MAX_EXTRA_FILES_TO_PROMPT:
- get_logger().debug(f"Too many deleted files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}")
+ get_logger().debug(
+ f"Too many deleted files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}"
+ )
files_walkthrough_prompt += f"\n... and {len(deleted_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more"
break
tokens_files_walkthrough = len(
- token_handler_only_description_prompt.encoder.encode(files_walkthrough_prompt))
- total_tokens = token_handler_only_description_prompt.prompt_tokens + tokens_files_walkthrough
+ token_handler_only_description_prompt.encoder.encode(
+ files_walkthrough_prompt
+ )
+ )
+ total_tokens = (
+ token_handler_only_description_prompt.prompt_tokens
+ + tokens_files_walkthrough
+ )
max_tokens_model = get_max_tokens(model)
if total_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
                # clip files_walkthrough to fit the tokens within the limit
- files_walkthrough_prompt = clip_tokens(files_walkthrough_prompt,
- max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD - token_handler_only_description_prompt.prompt_tokens,
- num_input_tokens=tokens_files_walkthrough)
+ files_walkthrough_prompt = clip_tokens(
+ files_walkthrough_prompt,
+ max_tokens_model
+ - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD
+ - token_handler_only_description_prompt.prompt_tokens,
+ num_input_tokens=tokens_files_walkthrough,
+ )
# PR header inference
- get_logger().debug(f"PR diff only description", artifact=files_walkthrough_prompt)
- prediction_headers = await self._get_prediction(model, patches_diff=files_walkthrough_prompt,
- prompt="pr_description_only_description_prompts")
- prediction_headers = prediction_headers.strip().removeprefix('```yaml').strip('`').strip()
+ get_logger().debug(
+ f"PR diff only description", artifact=files_walkthrough_prompt
+ )
+ prediction_headers = await self._get_prediction(
+ model,
+ patches_diff=files_walkthrough_prompt,
+ prompt="pr_description_only_description_prompts",
+ )
+ prediction_headers = (
+ prediction_headers.strip().removeprefix('```yaml').strip('`').strip()
+ )
# extend the tables with the files not shown
- files_walkthrough_extended = await self.extend_uncovered_files(files_walkthrough)
+ files_walkthrough_extended = await self.extend_uncovered_files(
+ files_walkthrough
+ )
# final processing
- self.prediction = prediction_headers + "\n" + "pr_files:\n" + files_walkthrough_extended
+ self.prediction = (
+ prediction_headers + "\n" + "pr_files:\n" + files_walkthrough_extended
+ )
if not load_yaml(self.prediction, keys_fix_yaml=self.keys_fix):
- get_logger().error(f"Error getting valid YAML in large PR handling for describe {self.pr_id}")
+ get_logger().error(
+ f"Error getting valid YAML in large PR handling for describe {self.pr_id}"
+ )
if load_yaml(prediction_headers, keys_fix_yaml=self.keys_fix):
get_logger().debug(f"Using only headers for describe {self.pr_id}")
self.prediction = prediction_headers
@@ -321,12 +467,17 @@ class PRDescription:
prediction = original_prediction
# get the original prediction filenames
- original_prediction_loaded = load_yaml(original_prediction, keys_fix_yaml=self.keys_fix)
+ original_prediction_loaded = load_yaml(
+ original_prediction, keys_fix_yaml=self.keys_fix
+ )
if isinstance(original_prediction_loaded, list):
original_prediction_dict = {"pr_files": original_prediction_loaded}
else:
original_prediction_dict = original_prediction_loaded
- filenames_predicted = [file['filename'].strip() for file in original_prediction_dict.get('pr_files', [])]
+ filenames_predicted = [
+ file['filename'].strip()
+ for file in original_prediction_dict.get('pr_files', [])
+ ]
# extend the prediction with additional files not included in the original prediction
pr_files = self.git_provider.get_diff_files()
@@ -349,7 +500,9 @@ class PRDescription:
additional files
"""
prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
- get_logger().debug(f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_OUTPUT}")
+ get_logger().debug(
+ f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_OUTPUT}"
+ )
break
extra_file_yaml = f"""\
@@ -364,10 +517,18 @@ class PRDescription:
# merge the two dictionaries
if counter_extra_files > 0:
- get_logger().info(f"Adding {counter_extra_files} unprocessed extra files to table prediction")
- prediction_extra_dict = load_yaml(prediction_extra, keys_fix_yaml=self.keys_fix)
- if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
- original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
+ get_logger().info(
+ f"Adding {counter_extra_files} unprocessed extra files to table prediction"
+ )
+ prediction_extra_dict = load_yaml(
+ prediction_extra, keys_fix_yaml=self.keys_fix
+ )
+ if isinstance(original_prediction_dict, dict) and isinstance(
+ prediction_extra_dict, dict
+ ):
+ original_prediction_dict["pr_files"].extend(
+ prediction_extra_dict["pr_files"]
+ )
new_yaml = yaml.dump(original_prediction_dict)
if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix):
prediction = new_yaml
@@ -379,11 +540,12 @@ class PRDescription:
get_logger().error(f"Error extending uncovered files {self.pr_id}: {e}")
return original_prediction
-
async def extend_additional_files(self, remaining_files_list) -> str:
prediction = self.prediction
try:
- original_prediction_dict = load_yaml(self.prediction, keys_fix_yaml=self.keys_fix)
+ original_prediction_dict = load_yaml(
+ self.prediction, keys_fix_yaml=self.keys_fix
+ )
prediction_extra = "pr_files:"
for file in remaining_files_list:
extra_file_yaml = f"""\
@@ -397,10 +559,16 @@ class PRDescription:
additional files (token-limit)
"""
prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
- prediction_extra_dict = load_yaml(prediction_extra, keys_fix_yaml=self.keys_fix)
+ prediction_extra_dict = load_yaml(
+ prediction_extra, keys_fix_yaml=self.keys_fix
+ )
# merge the two dictionaries
- if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
- original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
+ if isinstance(original_prediction_dict, dict) and isinstance(
+ prediction_extra_dict, dict
+ ):
+ original_prediction_dict["pr_files"].extend(
+ prediction_extra_dict["pr_files"]
+ )
new_yaml = yaml.dump(original_prediction_dict)
if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix):
prediction = new_yaml
@@ -409,7 +577,9 @@ class PRDescription:
get_logger().error(f"Error extending additional files {self.pr_id}: {e}")
return self.prediction
- async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str:
+ async def _get_prediction(
+ self, model: str, patches_diff: str, prompt="pr_description_prompt"
+ ) -> str:
variables = copy.deepcopy(self.vars)
variables["diff"] = patches_diff # update diff
@@ -417,14 +587,18 @@ class PRDescription:
set_custom_labels(variables, self.git_provider)
self.variables = variables
- system_prompt = environment.from_string(get_settings().get(prompt, {}).get("system", "")).render(self.variables)
- user_prompt = environment.from_string(get_settings().get(prompt, {}).get("user", "")).render(self.variables)
+ system_prompt = environment.from_string(
+ get_settings().get(prompt, {}).get("system", "")
+ ).render(self.variables)
+ user_prompt = environment.from_string(
+ get_settings().get(prompt, {}).get("user", "")
+ ).render(self.variables)
response, finish_reason = await self.ai_handler.chat_completion(
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
- user=user_prompt
+ user=user_prompt,
)
return response
@@ -433,7 +607,10 @@ class PRDescription:
# Load the AI prediction data into a dictionary
self.data = load_yaml(self.prediction.strip(), keys_fix_yaml=self.keys_fix)
- if get_settings().pr_description.add_original_user_description and self.user_description:
+ if (
+ get_settings().pr_description.add_original_user_description
+ and self.user_description
+ ):
self.data["User Description"] = self.user_description
# re-order keys
@@ -459,7 +636,11 @@ class PRDescription:
pr_labels = self.data['labels']
elif type(self.data['labels']) == str:
pr_labels = self.data['labels'].split(',')
- elif 'type' in self.data and self.data['type'] and get_settings().pr_description.publish_labels:
+ elif (
+ 'type' in self.data
+ and self.data['type']
+ and get_settings().pr_description.publish_labels
+ ):
if type(self.data['type']) == list:
pr_labels = self.data['type']
elif type(self.data['type']) == str:
@@ -474,7 +655,9 @@ class PRDescription:
if label_i in d:
pr_labels[i] = d[label_i]
except Exception as e:
- get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
+ get_logger().error(
+ f"Error converting labels to original case {self.pr_id}: {e}"
+ )
return pr_labels
def _prepare_pr_answer_with_markers(self) -> Tuple[str, str, str, List[dict]]:
@@ -482,13 +665,13 @@ class PRDescription:
# Remove the 'PR Title' key from the dictionary
ai_title = self.data.pop('title', self.vars["title"])
- if (not get_settings().pr_description.generate_ai_title):
+ if not get_settings().pr_description.generate_ai_title:
# Assign the original PR title to the 'title' variable
title = self.vars["title"]
else:
# Assign the value of the 'PR Title' key to 'title' variable
title = ai_title
-
+
body = self.user_description
if get_settings().pr_description.include_generated_by_header:
ai_header = f"### 🤖 Generated by PR Agent at {self.git_provider.last_commit_id.sha}\n\n"
@@ -514,8 +697,9 @@ class PRDescription:
pr_file_changes = []
if ai_walkthrough and not re.search(r'', body):
try:
- walkthrough_gfm, pr_file_changes = self.process_pr_files_prediction(walkthrough_gfm,
- self.file_label_dict)
+ walkthrough_gfm, pr_file_changes = self.process_pr_files_prediction(
+ walkthrough_gfm, self.file_label_dict
+ )
body = body.replace('pr_agent:walkthrough', walkthrough_gfm)
except Exception as e:
get_logger().error(f"Failing to process walkthrough {self.pr_id}: {e}")
@@ -545,7 +729,7 @@ class PRDescription:
# Remove the 'PR Title' key from the dictionary
ai_title = self.data.pop('title', self.vars["title"])
- if (not get_settings().pr_description.generate_ai_title):
+ if not get_settings().pr_description.generate_ai_title:
# Assign the original PR title to the 'title' variable
title = self.vars["title"]
else:
@@ -575,13 +759,20 @@ class PRDescription:
pr_body += f'- `{filename}`: {description}\n'
if self.git_provider.is_supported("gfm_markdown"):
pr_body += " \n"
- elif 'pr_files' in key.lower() and get_settings().pr_description.enable_semantic_files_types:
- changes_walkthrough, pr_file_changes = self.process_pr_files_prediction(changes_walkthrough, value)
+ elif (
+ 'pr_files' in key.lower()
+ and get_settings().pr_description.enable_semantic_files_types
+ ):
+ changes_walkthrough, pr_file_changes = self.process_pr_files_prediction(
+ changes_walkthrough, value
+ )
changes_walkthrough = f"{PRDescriptionHeader.CHANGES_WALKTHROUGH.value}\n{changes_walkthrough}"
elif key.lower().strip() == 'description':
if isinstance(value, list):
value = ', '.join(v.rstrip() for v in value)
- value = value.replace('\n-', '\n\n-').strip() # makes the bullet points more readable by adding double space
+ value = value.replace(
+ '\n-', '\n\n-'
+ ).strip() # makes the bullet points more readable by adding double space
pr_body += f"{value}\n"
else:
# if the value is a list, join its items by comma
@@ -591,24 +782,37 @@ class PRDescription:
if idx < len(self.data) - 1:
pr_body += "\n\n___\n\n"
- return title, pr_body, changes_walkthrough, pr_file_changes,
+ return (
+ title,
+ pr_body,
+ changes_walkthrough,
+ pr_file_changes,
+ )
def _prepare_file_labels(self):
file_label_dict = {}
- if (not self.data or not isinstance(self.data, dict) or
- 'pr_files' not in self.data or not self.data['pr_files']):
+ if (
+ not self.data
+ or not isinstance(self.data, dict)
+ or 'pr_files' not in self.data
+ or not self.data['pr_files']
+ ):
return file_label_dict
for file in self.data['pr_files']:
try:
required_fields = ['changes_title', 'filename', 'label']
if not all(field in file for field in required_fields):
# can happen for example if a YAML generation was interrupted in the middle (no more tokens)
- get_logger().warning(f"Missing required fields in file label dict {self.pr_id}, skipping file",
- artifact={"file": file})
+ get_logger().warning(
+ f"Missing required fields in file label dict {self.pr_id}, skipping file",
+ artifact={"file": file},
+ )
continue
if not file.get('changes_title'):
- get_logger().warning(f"Empty changes title or summary in file label dict {self.pr_id}, skipping file",
- artifact={"file": file})
+ get_logger().warning(
+ f"Empty changes title or summary in file label dict {self.pr_id}, skipping file",
+ artifact={"file": file},
+ )
continue
filename = file['filename'].replace("'", "`").replace('"', '`')
changes_summary = file.get('changes_summary', "").strip()
@@ -616,7 +820,9 @@ class PRDescription:
label = file.get('label').strip().lower()
if label not in file_label_dict:
file_label_dict[label] = []
- file_label_dict[label].append((filename, changes_title, changes_summary))
+ file_label_dict[label].append(
+ (filename, changes_title, changes_summary)
+ )
except Exception as e:
get_logger().error(f"Error preparing file label dict {self.pr_id}: {e}")
pass
@@ -640,7 +846,9 @@ class PRDescription:
header = f"相关文件"
delta = 75
# header += " " * delta
- pr_body += f""" | {header} | """
+ pr_body += (
+ f""" | {header} | """
+ )
pr_body += """ | """
for semantic_label in value.keys():
s_label = semantic_label.strip("'").strip('"')
@@ -651,14 +859,22 @@ class PRDescription:
pr_body += f"""{len(list_tuples)} files"""
else:
pr_body += f""""""
- for filename, file_changes_title, file_change_description in list_tuples:
+ for (
+ filename,
+ file_changes_title,
+ file_change_description,
+ ) in list_tuples:
filename = filename.replace("'", "`").rstrip()
filename_publish = filename.split("/")[-1]
if file_changes_title and file_changes_title.strip() != "...":
file_changes_title_code = f"{file_changes_title}"
- file_changes_title_code_br = insert_br_after_x_chars(file_changes_title_code, x=(delta - 5)).strip()
+ file_changes_title_code_br = insert_br_after_x_chars(
+ file_changes_title_code, x=(delta - 5)
+ ).strip()
if len(file_changes_title_code_br) < (delta - 5):
- file_changes_title_code_br += " " * ((delta - 5) - len(file_changes_title_code_br))
+ file_changes_title_code_br += " " * (
+ (delta - 5) - len(file_changes_title_code_br)
+ )
filename_publish = f"{filename_publish}{file_changes_title_code_br}"
else:
filename_publish = f"{filename_publish}"
@@ -679,15 +895,30 @@ class PRDescription:
link = ""
if hasattr(self.git_provider, 'get_line_link'):
filename = filename.strip()
- link = self.git_provider.get_line_link(filename, relevant_line_start=-1)
- if (not link or not diff_plus_minus) and ('additional files' not in filename.lower()):
- get_logger().warning(f"Error getting line link for '{filename}'")
+ link = self.git_provider.get_line_link(
+ filename, relevant_line_start=-1
+ )
+ if (not link or not diff_plus_minus) and (
+ 'additional files' not in filename.lower()
+ ):
+ get_logger().warning(
+ f"Error getting line link for '{filename}'"
+ )
continue
# Add file data to the PR body
- file_change_description_br = insert_br_after_x_chars(file_change_description, x=(delta - 5))
- pr_body = self.add_file_data(delta_nbsp, diff_plus_minus, file_change_description_br, filename,
- filename_publish, link, pr_body)
+ file_change_description_br = insert_br_after_x_chars(
+ file_change_description, x=(delta - 5)
+ )
+ pr_body = self.add_file_data(
+ delta_nbsp,
+ diff_plus_minus,
+ file_change_description_br,
+ filename,
+ filename_publish,
+ link,
+ pr_body,
+ )
# Close the collapsible file list
if use_collapsible_file_list:
@@ -697,13 +928,22 @@ class PRDescription:
pr_body += """ """
except Exception as e:
- get_logger().error(f"Error processing pr files to markdown {self.pr_id}: {str(e)}")
+ get_logger().error(
+ f"Error processing pr files to markdown {self.pr_id}: {str(e)}"
+ )
pass
return pr_body, pr_comments
- def add_file_data(self, delta_nbsp, diff_plus_minus, file_change_description_br, filename, filename_publish, link,
- pr_body) -> str:
-
+ def add_file_data(
+ self,
+ delta_nbsp,
+ diff_plus_minus,
+ file_change_description_br,
+ filename,
+ filename_publish,
+ link,
+ pr_body,
+ ) -> str:
if not file_change_description_br:
pr_body += f"""
|
@@ -735,6 +975,7 @@ class PRDescription:
"""
return pr_body
+
def count_chars_without_html(string):
if '<' not in string:
return len(string)
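The uncovered-files handling above appends stub entries for files that never made it into the prompt and merges them into the model's `pr_files` YAML, re-validating the merged document before use. A rough, self-contained sketch of that merge with PyYAML (the stub fields are simplified relative to the real prompt schema, and `yaml.safe_load` stands in for the repo's `load_yaml` helper):

```python
import yaml

original_prediction = """\
pr_files:
- filename: src/app.py
  changes_title: Add retry logic
  label: enhancement
"""

# Stub entry for a file that did not fit in the prompt (hypothetical schema).
prediction_extra = """\
pr_files:
- filename: src/helpers.py
  changes_title: "..."
  label: additional files
"""

original_dict = yaml.safe_load(original_prediction)
extra_dict = yaml.safe_load(prediction_extra)

# Merge the two dictionaries only when both parsed into the expected shape.
if isinstance(original_dict, dict) and isinstance(extra_dict, dict):
    original_dict["pr_files"].extend(extra_dict["pr_files"])

new_yaml = yaml.dump(original_dict)
# Re-parse to confirm the merged document is still valid YAML before using it.
assert yaml.safe_load(new_yaml)
print(new_yaml)
```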
diff --git a/apps/utils/pr_agent/tools/pr_generate_labels.py b/apps/utils/pr_agent/tools/pr_generate_labels.py
index 85158e0..1eeabe7 100644
--- a/apps/utils/pr_agent/tools/pr_generate_labels.py
+++ b/apps/utils/pr_agent/tools/pr_generate_labels.py
@@ -16,8 +16,12 @@ from utils.pr_agent.log import get_logger
class PRGenerateLabels:
- def __init__(self, pr_url: str, args: list = None,
- ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
+ def __init__(
+ self,
+ pr_url: str,
+ args: list = None,
+ ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
+ ):
"""
Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
corresponding to the PR using an AI model.
@@ -93,7 +97,9 @@ class PRGenerateLabels:
elif pr_labels:
value = ', '.join(v for v in pr_labels)
pr_labels_text = f"## PR Labels:\n{value}\n"
- self.git_provider.publish_comment(pr_labels_text, is_temporary=False)
+ self.git_provider.publish_comment(
+ pr_labels_text, is_temporary=False
+ )
self.git_provider.remove_initial_comment()
except Exception as e:
get_logger().error(f"Error generating PR labels {self.pr_id}: {e}")
@@ -137,14 +143,18 @@ class PRGenerateLabels:
set_custom_labels(variables, self.git_provider)
self.variables = variables
- system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(self.variables)
- user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(self.variables)
+ system_prompt = environment.from_string(
+ get_settings().pr_custom_labels_prompt.system
+ ).render(self.variables)
+ user_prompt = environment.from_string(
+ get_settings().pr_custom_labels_prompt.user
+ ).render(self.variables)
response, finish_reason = await self.ai_handler.chat_completion(
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
- user=user_prompt
+ user=user_prompt,
)
return response
@@ -153,8 +163,6 @@ class PRGenerateLabels:
# Load the AI prediction data into a dictionary
self.data = load_yaml(self.prediction.strip())
-
-
def _prepare_labels(self) -> List[str]:
pr_types = []
@@ -174,6 +182,8 @@ class PRGenerateLabels:
if label_i in d:
pr_types[i] = d[label_i]
except Exception as e:
- get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
+ get_logger().error(
+ f"Error converting labels to original case {self.pr_id}: {e}"
+ )
return pr_types
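`_prepare_labels` lower-cases whatever the model returns and then maps it back to the repository's original label casing, leaving unknown labels untouched. The core of that normalization is a dictionary lookup, roughly as below (label names are invented):

```python
# Labels as defined by the repo (original casing) and as returned by the model.
user_labels = ["Bug fix", "Enhancement", "Documentation"]
pr_types = ["bug fix", "enhancement", "ci"]

# Map lower-cased label -> original label, then restore casing where a match exists.
original_case = {label.lower(): label for label in user_labels}
pr_types = [original_case.get(label.lower().strip(), label) for label in pr_types]

print(pr_types)  # ['Bug fix', 'Enhancement', 'ci']
```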
diff --git a/apps/utils/pr_agent/tools/pr_help_message.py b/apps/utils/pr_agent/tools/pr_help_message.py
index ca83b46..fda1d0e 100644
--- a/apps/utils/pr_agent/tools/pr_help_message.py
+++ b/apps/utils/pr_agent/tools/pr_help_message.py
@@ -12,7 +12,11 @@ from utils.pr_agent.algo.pr_processing import retry_with_fallback_models
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.utils import ModelType, clip_tokens, load_yaml, get_max_tokens
from utils.pr_agent.config_loader import get_settings
-from utils.pr_agent.git_providers import BitbucketServerProvider, GithubProvider, get_git_provider_with_context
+from utils.pr_agent.git_providers import (
+ BitbucketServerProvider,
+ GithubProvider,
+ get_git_provider_with_context,
+)
from utils.pr_agent.log import get_logger
@@ -29,31 +33,50 @@ def extract_header(snippet):
res = f"#{highest_header.lower().replace(' ', '-')}"
return res
+
class PRHelpMessage:
- def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, return_as_string=False):
+ def __init__(
+ self,
+ pr_url: str,
+ args=None,
+ ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
+ return_as_string=False,
+ ):
self.git_provider = get_git_provider_with_context(pr_url)
self.ai_handler = ai_handler()
self.question_str = self.parse_args(args)
self.return_as_string = return_as_string
- self.num_retrieved_snippets = get_settings().get('pr_help.num_retrieved_snippets', 5)
+ self.num_retrieved_snippets = get_settings().get(
+ 'pr_help.num_retrieved_snippets', 5
+ )
if self.question_str:
self.vars = {
"question": self.question_str,
"snippets": "",
}
- self.token_handler = TokenHandler(None,
- self.vars,
- get_settings().pr_help_prompts.system,
- get_settings().pr_help_prompts.user)
+ self.token_handler = TokenHandler(
+ None,
+ self.vars,
+ get_settings().pr_help_prompts.system,
+ get_settings().pr_help_prompts.user,
+ )
async def _prepare_prediction(self, model: str):
try:
variables = copy.deepcopy(self.vars)
environment = Environment(undefined=StrictUndefined)
- system_prompt = environment.from_string(get_settings().pr_help_prompts.system).render(variables)
- user_prompt = environment.from_string(get_settings().pr_help_prompts.user).render(variables)
+ system_prompt = environment.from_string(
+ get_settings().pr_help_prompts.system
+ ).render(variables)
+ user_prompt = environment.from_string(
+ get_settings().pr_help_prompts.user
+ ).render(variables)
response, finish_reason = await self.ai_handler.chat_completion(
- model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
+ model=model,
+ temperature=get_settings().config.temperature,
+ system=system_prompt,
+ user=user_prompt,
+ )
return response
except Exception as e:
get_logger().error(f"Error while preparing prediction: {e}")
@@ -81,7 +104,7 @@ class PRHelpMessage:
'.': '',
'?': '',
'!': '',
- ' ': '-'
+ ' ': '-',
}
# Compile regex pattern for characters to remove
@@ -90,37 +113,69 @@ class PRHelpMessage:
# Perform replacements in a single pass and convert to lowercase
return pattern.sub(lambda m: replacements[m.group()], cleaned).lower()
except Exception:
- get_logger().exception(f"Error while formatting markdown header", artifacts={'header': header})
+ get_logger().exception(
+ f"Error while formatting markdown header", artifacts={'header': header}
+ )
return ""
-
async def run(self):
try:
if self.question_str:
- get_logger().info(f'Answering a PR question about the PR {self.git_provider.pr_url} ')
+ get_logger().info(
+ f'Answering a PR question about the PR {self.git_provider.pr_url} '
+ )
if not get_settings().get('openai.key'):
if get_settings().config.publish_output:
self.git_provider.publish_comment(
- "The `Help` tool chat feature requires an OpenAI API key for calculating embeddings")
+ "The `Help` tool chat feature requires an OpenAI API key for calculating embeddings"
+ )
else:
- get_logger().error("The `Help` tool chat feature requires an OpenAI API key for calculating embeddings")
+ get_logger().error(
+ "The `Help` tool chat feature requires an OpenAI API key for calculating embeddings"
+ )
return
# current path
- docs_path= Path(__file__).parent.parent.parent / 'docs' / 'docs'
+ docs_path = Path(__file__).parent.parent.parent / 'docs' / 'docs'
# get all the 'md' files inside docs_path and its subdirectories
md_files = list(docs_path.glob('**/*.md'))
folders_to_exclude = ['/finetuning_benchmark/']
- files_to_exclude = {'EXAMPLE_BEST_PRACTICE.md', 'compression_strategy.md', '/docs/overview/index.md'}
- md_files = [file for file in md_files if not any(folder in str(file) for folder in folders_to_exclude) and not any(file.name == file_to_exclude for file_to_exclude in files_to_exclude)]
+ files_to_exclude = {
+ 'EXAMPLE_BEST_PRACTICE.md',
+ 'compression_strategy.md',
+ '/docs/overview/index.md',
+ }
+ md_files = [
+ file
+ for file in md_files
+ if not any(folder in str(file) for folder in folders_to_exclude)
+ and not any(
+ file.name == file_to_exclude
+ for file_to_exclude in files_to_exclude
+ )
+ ]
# sort the 'md_files' so that 'priority_files' will be at the top
- priority_files_strings = ['/docs/index.md', '/usage-guide', 'tools/describe.md', 'tools/review.md',
- 'tools/improve.md', '/faq']
- md_files_priority = [file for file in md_files if
- any(priority_string in str(file) for priority_string in priority_files_strings)]
- md_files_not_priority = [file for file in md_files if file not in md_files_priority]
+ priority_files_strings = [
+ '/docs/index.md',
+ '/usage-guide',
+ 'tools/describe.md',
+ 'tools/review.md',
+ 'tools/improve.md',
+ '/faq',
+ ]
+ md_files_priority = [
+ file
+ for file in md_files
+ if any(
+ priority_string in str(file)
+ for priority_string in priority_files_strings
+ )
+ ]
+ md_files_not_priority = [
+ file for file in md_files if file not in md_files_priority
+ ]
md_files = md_files_priority + md_files_not_priority
docs_prompt = ""
@@ -132,24 +187,36 @@ class PRHelpMessage:
except Exception as e:
get_logger().error(f"Error while reading the file {file}: {e}")
token_count = self.token_handler.count_tokens(docs_prompt)
- get_logger().debug(f"Token count of full documentation website: {token_count}")
+ get_logger().debug(
+ f"Token count of full documentation website: {token_count}"
+ )
model = get_settings().config.model
if model in MAX_TOKENS:
- max_tokens_full = MAX_TOKENS[model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
+ max_tokens_full = MAX_TOKENS[
+ model
+ ] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
else:
max_tokens_full = get_max_tokens(model)
delta_output = 2000
if token_count > max_tokens_full - delta_output:
- get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message.")
- docs_prompt = clip_tokens(docs_prompt, max_tokens_full - delta_output)
+ get_logger().info(
+ f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message."
+ )
+ docs_prompt = clip_tokens(
+ docs_prompt, max_tokens_full - delta_output
+ )
self.vars['snippets'] = docs_prompt.strip()
# run the AI model
- response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
+ response = await retry_with_fallback_models(
+ self._prepare_prediction, model_type=ModelType.REGULAR
+ )
response_yaml = load_yaml(response)
if isinstance(response_yaml, str):
- get_logger().warning(f"failing to parse response: {response_yaml}, publishing the response as is")
+ get_logger().warning(
+ f"failing to parse response: {response_yaml}, publishing the response as is"
+ )
if get_settings().config.publish_output:
answer_str = f"### Question: \n{self.question_str}\n\n"
answer_str += f"### Answer:\n\n"
@@ -160,7 +227,9 @@ class PRHelpMessage:
relevant_sections = response_yaml.get('relevant_sections')
if not relevant_sections:
- get_logger().info(f"Could not find relevant answer for the question: {self.question_str}")
+ get_logger().info(
+ f"Could not find relevant answer for the question: {self.question_str}"
+ )
if get_settings().config.publish_output:
answer_str = f"### Question: \n{self.question_str}\n\n"
answer_str += f"### Answer:\n\n"
@@ -178,29 +247,38 @@ class PRHelpMessage:
for section in relevant_sections:
file = section.get('file_name').strip().removesuffix('.md')
if str(section['relevant_section_header_string']).strip():
- markdown_header = self.format_markdown_header(section['relevant_section_header_string'])
+ markdown_header = self.format_markdown_header(
+ section['relevant_section_header_string']
+ )
answer_str += f"> - {base_path}{file}#{markdown_header}\n"
else:
answer_str += f"> - {base_path}{file}\n"
-
# publish the answer
if get_settings().config.publish_output:
self.git_provider.publish_comment(answer_str)
else:
get_logger().info(f"Answer:\n{answer_str}")
else:
- if not isinstance(self.git_provider, BitbucketServerProvider) and not self.git_provider.is_supported("gfm_markdown"):
+ if not isinstance(
+ self.git_provider, BitbucketServerProvider
+ ) and not self.git_provider.is_supported("gfm_markdown"):
self.git_provider.publish_comment(
- "The `Help` tool requires gfm markdown, which is not supported by your code platform.")
+ "The `Help` tool requires gfm markdown, which is not supported by your code platform."
+ )
return
get_logger().info('Getting PR Help Message...')
- relevant_configs = {'pr_help': dict(get_settings().pr_help),
- 'config': dict(get_settings().config)}
+ relevant_configs = {
+ 'pr_help': dict(get_settings().pr_help),
+ 'config': dict(get_settings().config),
+ }
get_logger().debug("Relevant configs", artifacts=relevant_configs)
pr_comment = "## PR Agent Walkthrough 🤖\n\n"
- pr_comment += "Welcome to the PR Agent, an AI-powered tool for automated pull request analysis, feedback, suggestions and more."""
+ pr_comment += (
+ "Welcome to the PR Agent, an AI-powered tool for automated pull request analysis, feedback, suggestions and more."
+ ""
+ )
pr_comment += "\n\nHere is a list of tools you can use to interact with the PR Agent:\n"
base_path = "https://pr-agent-docs.codium.ai/tools"
@@ -211,32 +289,58 @@ class PRHelpMessage:
tool_names.append(f"[UPDATE CHANGELOG]({base_path}/update_changelog/)")
tool_names.append(f"[ADD DOCS]({base_path}/documentation/) 💎")
tool_names.append(f"[TEST]({base_path}/test/) 💎")
- tool_names.append(f"[IMPROVE COMPONENT]({base_path}/improve_component/) 💎")
+ tool_names.append(
+ f"[IMPROVE COMPONENT]({base_path}/improve_component/) 💎"
+ )
tool_names.append(f"[ANALYZE]({base_path}/analyze/) 💎")
tool_names.append(f"[ASK]({base_path}/ask/)")
tool_names.append(f"[SIMILAR ISSUE]({base_path}/similar_issues/)")
- tool_names.append(f"[GENERATE CUSTOM LABELS]({base_path}/custom_labels/) 💎")
+ tool_names.append(
+ f"[GENERATE CUSTOM LABELS]({base_path}/custom_labels/) 💎"
+ )
tool_names.append(f"[CI FEEDBACK]({base_path}/ci_feedback/) 💎")
tool_names.append(f"[CUSTOM PROMPT]({base_path}/custom_prompt/) 💎")
tool_names.append(f"[IMPLEMENT]({base_path}/implement/) 💎")
descriptions = []
- descriptions.append("Generates PR description - title, type, summary, code walkthrough and labels")
- descriptions.append("Adjustable feedback about the PR, possible issues, security concerns, review effort and more")
+ descriptions.append(
+ "Generates PR description - title, type, summary, code walkthrough and labels"
+ )
+ descriptions.append(
+ "Adjustable feedback about the PR, possible issues, security concerns, review effort and more"
+ )
descriptions.append("Code suggestions for improving the PR")
descriptions.append("Automatically updates the changelog")
- descriptions.append("Generates documentation to methods/functions/classes that changed in the PR")
- descriptions.append("Generates unit tests for a specific component, based on the PR code change")
- descriptions.append("Code suggestions for a specific component that changed in the PR")
- descriptions.append("Identifies code components that changed in the PR, and enables to interactively generate tests, docs, and code suggestions for each component")
+ descriptions.append(
+ "Generates documentation to methods/functions/classes that changed in the PR"
+ )
+ descriptions.append(
+ "Generates unit tests for a specific component, based on the PR code change"
+ )
+ descriptions.append(
+ "Code suggestions for a specific component that changed in the PR"
+ )
+ descriptions.append(
+ "Identifies code components that changed in the PR, and enables to interactively generate tests, docs, and code suggestions for each component"
+ )
descriptions.append("Answering free-text questions about the PR")
- descriptions.append("Automatically retrieves and presents similar issues")
- descriptions.append("Generates custom labels for the PR, based on specific guidelines defined by the user")
- descriptions.append("Generates feedback and analysis for a failed CI job")
- descriptions.append("Generates custom suggestions for improving the PR code, derived only from a specific guidelines prompt defined by the user")
- descriptions.append("Generates implementation code from review suggestions")
+ descriptions.append(
+ "Automatically retrieves and presents similar issues"
+ )
+ descriptions.append(
+ "Generates custom labels for the PR, based on specific guidelines defined by the user"
+ )
+ descriptions.append(
+ "Generates feedback and analysis for a failed CI job"
+ )
+ descriptions.append(
+ "Generates custom suggestions for improving the PR code, derived only from a specific guidelines prompt defined by the user"
+ )
+ descriptions.append(
+ "Generates implementation code from review suggestions"
+ )
- commands =[]
+ commands = []
commands.append("`/describe`")
commands.append("`/review`")
commands.append("`/improve`")
@@ -271,7 +375,9 @@ class PRHelpMessage:
checkbox_list.append("[*]")
checkbox_list.append("[*]")
- if isinstance(self.git_provider, GithubProvider) and not get_settings().config.get('disable_checkboxes', False):
+ if isinstance(
+ self.git_provider, GithubProvider
+ ) and not get_settings().config.get('disable_checkboxes', False):
pr_comment += f"| Tool | Description | Trigger Interactively :gem: | "
for i in range(len(tool_names)):
pr_comment += f"\n| \n\n{tool_names[i]} | \n{descriptions[i]} | \n\n\n{checkbox_list[i]}\n | "
diff --git a/apps/utils/pr_agent/tools/pr_line_questions.py b/apps/utils/pr_agent/tools/pr_line_questions.py
index 5067be1..60d330a 100644
--- a/apps/utils/pr_agent/tools/pr_line_questions.py
+++ b/apps/utils/pr_agent/tools/pr_line_questions.py
@@ -5,8 +5,7 @@ from jinja2 import Environment, StrictUndefined
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
-from utils.pr_agent.algo.git_patch_processing import (
- extract_hunk_lines_from_patch)
+from utils.pr_agent.algo.git_patch_processing import extract_hunk_lines_from_patch
from utils.pr_agent.algo.pr_processing import retry_with_fallback_models
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.utils import ModelType
@@ -17,7 +16,12 @@ from utils.pr_agent.log import get_logger
class PR_LineQuestions:
- def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
+ def __init__(
+ self,
+ pr_url: str,
+ args=None,
+ ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
+ ):
self.question_str = self.parse_args(args)
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
@@ -34,10 +38,12 @@ class PR_LineQuestions:
"full_hunk": "",
"selected_lines": "",
}
- self.token_handler = TokenHandler(self.git_provider.pr,
- self.vars,
- get_settings().pr_line_questions_prompt.system,
- get_settings().pr_line_questions_prompt.user)
+ self.token_handler = TokenHandler(
+ self.git_provider.pr,
+ self.vars,
+ get_settings().pr_line_questions_prompt.system,
+ get_settings().pr_line_questions_prompt.user,
+ )
self.patches_diff = None
self.prediction = None
@@ -48,7 +54,6 @@ class PR_LineQuestions:
question_str = ""
return question_str
-
async def run(self):
get_logger().info('Answering a PR lines question...')
# if get_settings().config.publish_output:
@@ -62,22 +67,27 @@ class PR_LineQuestions:
file_name = get_settings().get('file_name', '')
comment_id = get_settings().get('comment_id', '')
if ask_diff:
- self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch(ask_diff,
- file_name,
- line_start=line_start,
- line_end=line_end,
- side=side
- )
+ self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch(
+ ask_diff, file_name, line_start=line_start, line_end=line_end, side=side
+ )
else:
diff_files = self.git_provider.get_diff_files()
for file in diff_files:
if file.filename == file_name:
- self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch(file.patch, file.filename,
- line_start=line_start,
- line_end=line_end,
- side=side)
+ (
+ self.patch_with_lines,
+ self.selected_lines,
+ ) = extract_hunk_lines_from_patch(
+ file.patch,
+ file.filename,
+ line_start=line_start,
+ line_end=line_end,
+ side=side,
+ )
if self.patch_with_lines:
- model_answer = await retry_with_fallback_models(self._get_prediction, model_type=ModelType.WEAK)
+ model_answer = await retry_with_fallback_models(
+ self._get_prediction, model_type=ModelType.WEAK
+ )
# sanitize the answer so that no line will start with "/"
model_answer_sanitized = model_answer.strip().replace("\n/", "\n /")
if model_answer_sanitized.startswith("/"):
@@ -85,7 +95,9 @@ class PR_LineQuestions:
get_logger().info('Preparing answer...')
if comment_id:
- self.git_provider.reply_to_comment_from_comment_id(comment_id, model_answer_sanitized)
+ self.git_provider.reply_to_comment_from_comment_id(
+ comment_id, model_answer_sanitized
+ )
else:
self.git_provider.publish_comment(model_answer_sanitized)
@@ -96,8 +108,12 @@ class PR_LineQuestions:
variables["full_hunk"] = self.patch_with_lines # update diff
variables["selected_lines"] = self.selected_lines
environment = Environment(undefined=StrictUndefined)
- system_prompt = environment.from_string(get_settings().pr_line_questions_prompt.system).render(variables)
- user_prompt = environment.from_string(get_settings().pr_line_questions_prompt.user).render(variables)
+ system_prompt = environment.from_string(
+ get_settings().pr_line_questions_prompt.system
+ ).render(variables)
+ user_prompt = environment.from_string(
+ get_settings().pr_line_questions_prompt.user
+ ).render(variables)
if get_settings().config.verbosity_level >= 2:
# get_logger().info(f"\nSystem prompt:\n{system_prompt}")
# get_logger().info(f"\nUser prompt:\n{user_prompt}")
@@ -105,5 +121,9 @@ class PR_LineQuestions:
print(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(
- model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
+ model=model,
+ temperature=get_settings().config.temperature,
+ system=system_prompt,
+ user=user_prompt,
+ )
return response
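Both the line-question and the full-PR question tools sanitize the model's answer so that no line begins with `/`, since a leading slash in a posted comment could otherwise be parsed as a new bot command. The transformation is a couple of string operations, roughly:

```python
model_answer = "/improve suggestions:\n/describe the PR first"

# Prefix a space to any line that starts with "/" so it cannot trigger a bot command.
model_answer_sanitized = model_answer.strip().replace("\n/", "\n /")
if model_answer_sanitized.startswith("/"):
    model_answer_sanitized = " " + model_answer_sanitized

print(repr(model_answer_sanitized))
# ' /improve suggestions:\n /describe the PR first'
```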
diff --git a/apps/utils/pr_agent/tools/pr_questions.py b/apps/utils/pr_agent/tools/pr_questions.py
index a1dae7b..081e03b 100644
--- a/apps/utils/pr_agent/tools/pr_questions.py
+++ b/apps/utils/pr_agent/tools/pr_questions.py
@@ -16,7 +16,12 @@ from utils.pr_agent.servers.help import HelpMessage
class PRQuestions:
- def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
+ def __init__(
+ self,
+ pr_url: str,
+ args=None,
+ ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
+ ):
question_str = self.parse_args(args)
self.pr_url = pr_url
self.git_provider = get_git_provider()(pr_url)
@@ -36,10 +41,12 @@ class PRQuestions:
"questions": self.question_str,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
- self.token_handler = TokenHandler(self.git_provider.pr,
- self.vars,
- get_settings().pr_questions_prompt.system,
- get_settings().pr_questions_prompt.user)
+ self.token_handler = TokenHandler(
+ self.git_provider.pr,
+ self.vars,
+ get_settings().pr_questions_prompt.system,
+ get_settings().pr_questions_prompt.user,
+ )
self.patches_diff = None
self.prediction = None
@@ -52,8 +59,10 @@ class PRQuestions:
async def run(self):
get_logger().info(f'Answering a PR question about the PR {self.pr_url} ')
- relevant_configs = {'pr_questions': dict(get_settings().pr_questions),
- 'config': dict(get_settings().config)}
+ relevant_configs = {
+ 'pr_questions': dict(get_settings().pr_questions),
+ 'config': dict(get_settings().config),
+ }
get_logger().debug("Relevant configs", artifacts=relevant_configs)
if get_settings().config.publish_output:
self.git_provider.publish_comment("思考回答中...", is_temporary=True)
@@ -63,12 +72,17 @@ class PRQuestions:
if img_path:
get_logger().debug(f"Image path identified", artifact=img_path)
- await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK)
+ await retry_with_fallback_models(
+ self._prepare_prediction, model_type=ModelType.WEAK
+ )
pr_comment = self._prepare_pr_answer()
get_logger().debug(f"PR output", artifact=pr_comment)
- if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_questions.enable_help_text:
+ if (
+ self.git_provider.is_supported("gfm_markdown")
+ and get_settings().pr_questions.enable_help_text
+ ):
pr_comment += " \n\n 💡 Tool usage guide: \n\n"
pr_comment += HelpMessage.get_ask_usage_guide()
pr_comment += "\n \n"
@@ -85,7 +99,9 @@ class PRQuestions:
# /ask question ... > 
img_path = self.question_str.split('![image]')[1].strip().strip('()')
self.vars['img_path'] = img_path
- elif 'https://' in self.question_str and ('.png' in self.question_str or 'jpg' in self.question_str): # direct image link
+ elif 'https://' in self.question_str and (
+ '.png' in self.question_str or 'jpg' in self.question_str
+ ): # direct image link
# include https:// in the image path
img_path = 'https://' + self.question_str.split('https://')[1]
self.vars['img_path'] = img_path
@@ -104,16 +120,28 @@ class PRQuestions:
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
- system_prompt = environment.from_string(get_settings().pr_questions_prompt.system).render(variables)
- user_prompt = environment.from_string(get_settings().pr_questions_prompt.user).render(variables)
+ system_prompt = environment.from_string(
+ get_settings().pr_questions_prompt.system
+ ).render(variables)
+ user_prompt = environment.from_string(
+ get_settings().pr_questions_prompt.user
+ ).render(variables)
if 'img_path' in variables:
img_path = self.vars['img_path']
- response, finish_reason = await (self.ai_handler.chat_completion
- (model=model, temperature=get_settings().config.temperature,
- system=system_prompt, user=user_prompt, img_path=img_path))
+ response, finish_reason = await self.ai_handler.chat_completion(
+ model=model,
+ temperature=get_settings().config.temperature,
+ system=system_prompt,
+ user=user_prompt,
+ img_path=img_path,
+ )
else:
response, finish_reason = await self.ai_handler.chat_completion(
- model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
+ model=model,
+ temperature=get_settings().config.temperature,
+ system=system_prompt,
+ user=user_prompt,
+ )
return response
def _prepare_pr_answer(self) -> str:
@@ -123,9 +151,13 @@ class PRQuestions:
if model_answer_sanitized.startswith("/"):
model_answer_sanitized = " " + model_answer_sanitized
if model_answer_sanitized != model_answer:
- get_logger().debug(f"Sanitized model answer",
- artifact={"model_answer": model_answer, "sanitized_answer": model_answer_sanitized})
-
+ get_logger().debug(
+ f"Sanitized model answer",
+ artifact={
+ "model_answer": model_answer,
+ "sanitized_answer": model_answer_sanitized,
+ },
+ )
answer_str = f"### **Ask**❓\n{self.question_str}\n\n"
answer_str += f"### **Answer:**\n{model_answer_sanitized}\n\n"
diff --git a/apps/utils/pr_agent/tools/pr_reviewer.py b/apps/utils/pr_agent/tools/pr_reviewer.py
index e21628f..ef4b7dd 100644
--- a/apps/utils/pr_agent/tools/pr_reviewer.py
+++ b/apps/utils/pr_agent/tools/pr_reviewer.py
@@ -7,21 +7,29 @@ from jinja2 import Environment, StrictUndefined
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
-from utils.pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files,
- get_pr_diff,
- retry_with_fallback_models)
+from utils.pr_agent.algo.pr_processing import (
+ add_ai_metadata_to_diff_files,
+ get_pr_diff,
+ retry_with_fallback_models,
+)
from utils.pr_agent.algo.token_handler import TokenHandler
-from utils.pr_agent.algo.utils import (ModelType, PRReviewHeader,
- convert_to_markdown_v2, github_action_output,
- load_yaml, show_relevant_configurations)
+from utils.pr_agent.algo.utils import (
+ ModelType,
+ PRReviewHeader,
+ convert_to_markdown_v2,
+ github_action_output,
+ load_yaml,
+ show_relevant_configurations,
+)
from utils.pr_agent.config_loader import get_settings
-from utils.pr_agent.git_providers import (get_git_provider_with_context)
-from utils.pr_agent.git_providers.git_provider import (IncrementalPR,
- get_main_pr_language)
+from utils.pr_agent.git_providers import get_git_provider_with_context
+from utils.pr_agent.git_providers.git_provider import (
+ IncrementalPR,
+ get_main_pr_language,
+)
from utils.pr_agent.log import get_logger
from utils.pr_agent.servers.help import HelpMessage
-from utils.pr_agent.tools.ticket_pr_compliance_check import (
- extract_and_cache_pr_tickets)
+from utils.pr_agent.tools.ticket_pr_compliance_check import extract_and_cache_pr_tickets
class PRReviewer:
@@ -29,8 +37,14 @@ class PRReviewer:
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
"""
- def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
- ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
+ def __init__(
+ self,
+ pr_url: str,
+ is_answer: bool = False,
+ is_auto: bool = False,
+ args: list = None,
+ ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
+ ):
"""
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
@@ -55,16 +69,23 @@ class PRReviewer:
self.is_auto = is_auto
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
- raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
+ raise Exception(
+ f"Answer mode is not supported for {get_settings().config.git_provider} for now"
+ )
self.ai_handler = ai_handler()
self.ai_handler.main_pr_language = self.main_language
self.patches_diff = None
self.prediction = None
answer_str, question_str = self._get_user_answers()
- self.pr_description, self.pr_description_files = (
- self.git_provider.get_pr_description(split_changes_walkthrough=True))
- if (self.pr_description_files and get_settings().get("config.is_auto_command", False) and
- get_settings().get("config.enable_ai_metadata", False)):
+ (
+ self.pr_description,
+ self.pr_description_files,
+ ) = self.git_provider.get_pr_description(split_changes_walkthrough=True)
+ if (
+ self.pr_description_files
+ and get_settings().get("config.is_auto_command", False)
+ and get_settings().get("config.enable_ai_metadata", False)
+ ):
add_ai_metadata_to_diff_files(self.git_provider, self.pr_description_files)
get_logger().debug(f"AI metadata added to the this command")
else:
@@ -89,9 +110,11 @@ class PRReviewer:
"commit_messages_str": self.git_provider.get_commit_messages(),
"custom_labels": "",
"enable_custom_labels": get_settings().config.enable_custom_labels,
- "is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
+ "is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
"related_tickets": get_settings().get('related_tickets', []),
- 'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False),
+ 'duplicate_prompt_examples': get_settings().config.get(
+ 'duplicate_prompt_examples', False
+ ),
"date": datetime.datetime.now().strftime('%Y-%m-%d'),
}
@@ -99,7 +122,7 @@ class PRReviewer:
self.git_provider.pr,
self.vars,
get_settings().pr_review_prompt.system,
- get_settings().pr_review_prompt.user
+ get_settings().pr_review_prompt.user,
)
def parse_incremental(self, args: List[str]):
@@ -117,7 +140,10 @@ class PRReviewer:
get_logger().info(f"PR has no files: {self.pr_url}, skipping review")
return None
- if self.incremental.is_incremental and not self._can_run_incremental_review():
+ if (
+ self.incremental.is_incremental
+ and not self._can_run_incremental_review()
+ ):
return None
# if isinstance(self.args, list) and self.args and self.args[0] == 'auto_approve':
@@ -126,27 +152,41 @@ class PRReviewer:
# return None
get_logger().info(f'Reviewing PR: {self.pr_url} ...')
- relevant_configs = {'pr_reviewer': dict(get_settings().pr_reviewer),
- 'config': dict(get_settings().config)}
+ relevant_configs = {
+ 'pr_reviewer': dict(get_settings().pr_reviewer),
+ 'config': dict(get_settings().config),
+ }
get_logger().debug("Relevant configs", artifacts=relevant_configs)
# ticket extraction if exists
await extract_and_cache_pr_tickets(self.git_provider, self.vars)
- if self.incremental.is_incremental and hasattr(self.git_provider, "unreviewed_files_set") and not self.git_provider.unreviewed_files_set:
- get_logger().info(f"Incremental review is enabled for {self.pr_url} but there are no new files")
+ if (
+ self.incremental.is_incremental
+ and hasattr(self.git_provider, "unreviewed_files_set")
+ and not self.git_provider.unreviewed_files_set
+ ):
+ get_logger().info(
+ f"Incremental review is enabled for {self.pr_url} but there are no new files"
+ )
previous_review_url = ""
if hasattr(self.git_provider, "previous_review"):
previous_review_url = self.git_provider.previous_review.html_url
if get_settings().config.publish_output:
- self.git_provider.publish_comment(f"Incremental Review Skipped\n"
- f"No files were changed since the [previous PR Review]({previous_review_url})")
+ self.git_provider.publish_comment(
+ f"Incremental Review Skipped\n"
+ f"No files were changed since the [previous PR Review]({previous_review_url})"
+ )
return None
- if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False):
+ if get_settings().config.publish_output and not get_settings().config.get(
+ 'is_auto_command', False
+ ):
self.git_provider.publish_comment("准备评审中...", is_temporary=True)
- await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
+ await retry_with_fallback_models(
+ self._prepare_prediction, model_type=ModelType.REGULAR
+ )
if not self.prediction:
self.git_provider.remove_initial_comment()
return None
@@ -156,12 +196,19 @@ class PRReviewer:
if get_settings().config.publish_output:
# publish the review
- if get_settings().pr_reviewer.persistent_comment and not self.incremental.is_incremental:
- final_update_message = get_settings().pr_reviewer.final_update_message
- self.git_provider.publish_persistent_comment(pr_review,
- initial_header=f"{PRReviewHeader.REGULAR.value} 🔍",
- update_header=True,
- final_update_message=final_update_message, )
+ if (
+ get_settings().pr_reviewer.persistent_comment
+ and not self.incremental.is_incremental
+ ):
+ final_update_message = (
+ get_settings().pr_reviewer.final_update_message
+ )
+ self.git_provider.publish_persistent_comment(
+ pr_review,
+ initial_header=f"{PRReviewHeader.REGULAR.value} 🔍",
+ update_header=True,
+ final_update_message=final_update_message,
+ )
else:
self.git_provider.publish_comment(pr_review)
@@ -174,11 +221,13 @@ class PRReviewer:
get_logger().error(f"Failed to review PR: {e}")
async def _prepare_prediction(self, model: str) -> None:
- self.patches_diff = get_pr_diff(self.git_provider,
- self.token_handler,
- model,
- add_line_numbers_to_hunks=True,
- disable_extra_lines=False,)
+ self.patches_diff = get_pr_diff(
+ self.git_provider,
+ self.token_handler,
+ model,
+ add_line_numbers_to_hunks=True,
+ disable_extra_lines=False,
+ )
if self.patches_diff:
get_logger().debug(f"PR diff", diff=self.patches_diff)
@@ -201,14 +250,18 @@ class PRReviewer:
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
- system_prompt = environment.from_string(get_settings().pr_review_prompt.system).render(variables)
- user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables)
+ system_prompt = environment.from_string(
+ get_settings().pr_review_prompt.system
+ ).render(variables)
+ user_prompt = environment.from_string(
+ get_settings().pr_review_prompt.user
+ ).render(variables)
response, finish_reason = await self.ai_handler.chat_completion(
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
- user=user_prompt
+ user=user_prompt,
)
return response
@@ -220,10 +273,20 @@ class PRReviewer:
"""
first_key = 'review'
last_key = 'security_concerns'
- data = load_yaml(self.prediction.strip(),
- keys_fix_yaml=["ticket_compliance_check", "estimated_effort_to_review_[1-5]:", "security_concerns:", "key_issues_to_review:",
- "relevant_file:", "relevant_line:", "suggestion:"],
- first_key=first_key, last_key=last_key)
+ data = load_yaml(
+ self.prediction.strip(),
+ keys_fix_yaml=[
+ "ticket_compliance_check",
+ "estimated_effort_to_review_[1-5]:",
+ "security_concerns:",
+ "key_issues_to_review:",
+ "relevant_file:",
+ "relevant_line:",
+ "suggestion:",
+ ],
+ first_key=first_key,
+ last_key=last_key,
+ )
github_action_output(data, 'review')
# move data['review'] 'key_issues_to_review' key to the end of the dictionary
@@ -234,24 +297,38 @@ class PRReviewer:
incremental_review_markdown_text = None
# Add incremental review section
if self.incremental.is_incremental:
- last_commit_url = f"{self.git_provider.get_pr_url()}/commits/" \
- f"{self.git_provider.incremental.first_new_commit_sha}"
+ last_commit_url = (
+ f"{self.git_provider.get_pr_url()}/commits/"
+ f"{self.git_provider.incremental.first_new_commit_sha}"
+ )
incremental_review_markdown_text = f"Starting from commit {last_commit_url}"
- markdown_text = convert_to_markdown_v2(data, self.git_provider.is_supported("gfm_markdown"),
- incremental_review_markdown_text,
- git_provider=self.git_provider,
- files=self.git_provider.get_diff_files())
+ markdown_text = convert_to_markdown_v2(
+ data,
+ self.git_provider.is_supported("gfm_markdown"),
+ incremental_review_markdown_text,
+ git_provider=self.git_provider,
+ files=self.git_provider.get_diff_files(),
+ )
# Add help text if gfm_markdown is supported
- if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_reviewer.enable_help_text:
+ if (
+ self.git_provider.is_supported("gfm_markdown")
+ and get_settings().pr_reviewer.enable_help_text
+ ):
markdown_text += " \n\n 💡 Tool usage guide: \n\n"
markdown_text += HelpMessage.get_review_usage_guide()
markdown_text += "\n \n"
# Output the relevant configurations if enabled
- if get_settings().get('config', {}).get('output_relevant_configurations', False):
- markdown_text += show_relevant_configurations(relevant_section='pr_reviewer')
+ if (
+ get_settings()
+ .get('config', {})
+ .get('output_relevant_configurations', False)
+ ):
+ markdown_text += show_relevant_configurations(
+ relevant_section='pr_reviewer'
+ )
# Add custom labels from the review prediction (effort, security)
self.set_review_labels(data)
@@ -306,34 +383,50 @@ class PRReviewer:
if comment:
self.git_provider.remove_comment(comment)
except Exception as e:
- get_logger().exception(f"Failed to remove previous review comment, error: {e}")
+ get_logger().exception(
+ f"Failed to remove previous review comment, error: {e}"
+ )
def _can_run_incremental_review(self) -> bool:
"""Checks if we can run incremental review according the various configurations and previous review"""
# checking if running is auto mode but there are no new commits
if self.is_auto and not self.incremental.first_new_commit_sha:
- get_logger().info(f"Incremental review is enabled for {self.pr_url} but there are no new commits")
+ get_logger().info(
+ f"Incremental review is enabled for {self.pr_url} but there are no new commits"
+ )
return False
if not hasattr(self.git_provider, "get_incremental_commits"):
- get_logger().info(f"Incremental review is not supported for {get_settings().config.git_provider}")
+ get_logger().info(
+ f"Incremental review is not supported for {get_settings().config.git_provider}"
+ )
return False
# checking if there are enough commits to start the review
num_new_commits = len(self.incremental.commits_range)
- num_commits_threshold = get_settings().pr_reviewer.minimal_commits_for_incremental_review
+ num_commits_threshold = (
+ get_settings().pr_reviewer.minimal_commits_for_incremental_review
+ )
not_enough_commits = num_new_commits < num_commits_threshold
# checking if the commits are not too recent to start the review
recent_commits_threshold = datetime.datetime.now() - datetime.timedelta(
minutes=get_settings().pr_reviewer.minimal_minutes_for_incremental_review
)
last_seen_commit_date = (
- self.incremental.last_seen_commit.commit.author.date if self.incremental.last_seen_commit else None
+ self.incremental.last_seen_commit.commit.author.date
+ if self.incremental.last_seen_commit
+ else None
)
all_commits_too_recent = (
- last_seen_commit_date > recent_commits_threshold if self.incremental.last_seen_commit else False
+ last_seen_commit_date > recent_commits_threshold
+ if self.incremental.last_seen_commit
+ else False
)
# check all the thresholds or just one to start the review
- condition = any if get_settings().pr_reviewer.require_all_thresholds_for_incremental_review else all
+ condition = (
+ any
+ if get_settings().pr_reviewer.require_all_thresholds_for_incremental_review
+ else all
+ )
if condition((not_enough_commits, all_commits_too_recent)):
get_logger().info(
f"Incremental review is enabled for {self.pr_url} but didn't pass the threshold check to run:"
@@ -348,31 +441,55 @@ class PRReviewer:
return
if not get_settings().pr_reviewer.require_estimate_effort_to_review:
- get_settings().pr_reviewer.enable_review_labels_effort = False # we did not generate this output
+ get_settings().pr_reviewer.enable_review_labels_effort = (
+ False # we did not generate this output
+ )
if not get_settings().pr_reviewer.require_security_review:
- get_settings().pr_reviewer.enable_review_labels_security = False # we did not generate this output
+ get_settings().pr_reviewer.enable_review_labels_security = (
+ False # we did not generate this output
+ )
- if (get_settings().pr_reviewer.enable_review_labels_security or
- get_settings().pr_reviewer.enable_review_labels_effort):
+ if (
+ get_settings().pr_reviewer.enable_review_labels_security
+ or get_settings().pr_reviewer.enable_review_labels_effort
+ ):
try:
review_labels = []
if get_settings().pr_reviewer.enable_review_labels_effort:
- estimated_effort = data['review']['estimated_effort_to_review_[1-5]']
+ estimated_effort = data['review'][
+ 'estimated_effort_to_review_[1-5]'
+ ]
estimated_effort_number = 0
if isinstance(estimated_effort, str):
try:
- estimated_effort_number = int(estimated_effort.split(',')[0])
+ estimated_effort_number = int(
+ estimated_effort.split(',')[0]
+ )
except ValueError:
- get_logger().warning(f"Invalid estimated_effort value: {estimated_effort}")
+ get_logger().warning(
+ f"Invalid estimated_effort value: {estimated_effort}"
+ )
elif isinstance(estimated_effort, int):
estimated_effort_number = estimated_effort
else:
- get_logger().warning(f"Unexpected type for estimated_effort: {type(estimated_effort)}")
+ get_logger().warning(
+ f"Unexpected type for estimated_effort: {type(estimated_effort)}"
+ )
if 1 <= estimated_effort_number <= 5: # 1, because ...
- review_labels.append(f'Review effort {estimated_effort_number}/5')
- if get_settings().pr_reviewer.enable_review_labels_security and get_settings().pr_reviewer.require_security_review:
- security_concerns = data['review']['security_concerns'] # yes, because ...
- security_concerns_bool = 'yes' in security_concerns.lower() or 'true' in security_concerns.lower()
+ review_labels.append(
+ f'Review effort {estimated_effort_number}/5'
+ )
+ if (
+ get_settings().pr_reviewer.enable_review_labels_security
+ and get_settings().pr_reviewer.require_security_review
+ ):
+ security_concerns = data['review'][
+ 'security_concerns'
+ ] # yes, because ...
+ security_concerns_bool = (
+ 'yes' in security_concerns.lower()
+ or 'true' in security_concerns.lower()
+ )
if security_concerns_bool:
review_labels.append('Possible security concern')
@@ -381,17 +498,26 @@ class PRReviewer:
current_labels = []
get_logger().debug(f"Current labels:\n{current_labels}")
if current_labels:
- current_labels_filtered = [label for label in current_labels if
- not label.lower().startswith('review effort') and not label.lower().startswith(
- 'possible security concern')]
+ current_labels_filtered = [
+ label
+ for label in current_labels
+ if not label.lower().startswith('review effort')
+ and not label.lower().startswith('possible security concern')
+ ]
else:
current_labels_filtered = []
new_labels = review_labels + current_labels_filtered
- if (current_labels or review_labels) and sorted(new_labels) != sorted(current_labels):
- get_logger().info(f"Setting review labels:\n{review_labels + current_labels_filtered}")
+ if (current_labels or review_labels) and sorted(new_labels) != sorted(
+ current_labels
+ ):
+ get_logger().info(
+ f"Setting review labels:\n{review_labels + current_labels_filtered}"
+ )
self.git_provider.publish_labels(new_labels)
else:
- get_logger().info(f"Review labels are already set:\n{review_labels + current_labels_filtered}")
+ get_logger().info(
+ f"Review labels are already set:\n{review_labels + current_labels_filtered}"
+ )
except Exception as e:
get_logger().error(f"Failed to set review labels, error: {e}")
@@ -406,5 +532,7 @@ class PRReviewer:
self.git_provider.publish_comment("自动批准 PR")
else:
get_logger().info("Auto-approval option is disabled")
- self.git_provider.publish_comment("PR-Agent 的自动批准选项已禁用. "
- "你可以通过此设置打开 [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)")
+ self.git_provider.publish_comment(
+ "PR-Agent 的自动批准选项已禁用. "
+ "你可以通过此设置打开 [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)"
+ )
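
One non-obvious detail in the reformatted _can_run_incremental_review above: require_all_thresholds_for_incremental_review picks which builtin combines the two skip conditions. When it is true, every threshold must pass for the review to run, so any single failed check blocks it; when false, all checks must fail before the review is blocked. A small sketch of that pattern, with assumed parameter names:

# Sketch of the any()/all() gate used by _can_run_incremental_review (names assumed).
def blocks_incremental_review(not_enough_commits: bool,
                              commits_too_recent: bool,
                              require_all_thresholds: bool) -> bool:
    # require_all_thresholds=True: a single failed threshold (any) blocks the review.
    # require_all_thresholds=False: only block when every threshold fails (all).
    combine = any if require_all_thresholds else all
    return combine((not_enough_commits, commits_too_recent))

# blocks_incremental_review(True, False, require_all_thresholds=True)   -> True  (skip)
# blocks_incremental_review(True, False, require_all_thresholds=False)  -> False (review runs)
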
diff --git a/apps/utils/pr_agent/tools/pr_similar_issue.py b/apps/utils/pr_agent/tools/pr_similar_issue.py
index 6f9ea20..a8d2300 100644
--- a/apps/utils/pr_agent/tools/pr_similar_issue.py
+++ b/apps/utils/pr_agent/tools/pr_similar_issue.py
@@ -24,12 +24,16 @@ class PRSimilarIssue:
self.max_issues_to_scan = get_settings().pr_similar_issue.max_issues_to_scan
self.issue_url = issue_url
self.git_provider = get_git_provider()()
- repo_name, issue_number = self.git_provider._parse_issue_url(issue_url.split('=')[-1])
+ repo_name, issue_number = self.git_provider._parse_issue_url(
+ issue_url.split('=')[-1]
+ )
self.git_provider.repo = repo_name
self.git_provider.repo_obj = self.git_provider.github_client.get_repo(repo_name)
self.token_handler = TokenHandler()
repo_obj = self.git_provider.repo_obj
- repo_name_for_index = self.repo_name_for_index = repo_obj.full_name.lower().replace('/', '-').replace('_/', '-')
+ repo_name_for_index = self.repo_name_for_index = (
+ repo_obj.full_name.lower().replace('/', '-').replace('_/', '-')
+ )
index_name = self.index_name = "codium-ai-pr-agent-issues"
if get_settings().pr_similar_issue.vectordb == "pinecone":
@@ -38,17 +42,30 @@ class PRSimilarIssue:
import pinecone
from pinecone_datasets import Dataset, DatasetMetadata
except:
- raise Exception("Please install 'pinecone' and 'pinecone_datasets' to use pinecone as vectordb")
+ raise Exception(
+ "Please install 'pinecone' and 'pinecone_datasets' to use pinecone as vectordb"
+ )
# assuming pinecone api key and environment are set in secrets file
try:
api_key = get_settings().pinecone.api_key
environment = get_settings().pinecone.environment
except Exception:
if not self.cli_mode:
- repo_name, original_issue_number = self.git_provider._parse_issue_url(self.issue_url.split('=')[-1])
- issue_main = self.git_provider.repo_obj.get_issue(original_issue_number)
- issue_main.create_comment("Please set pinecone api key and environment in secrets file")
- raise Exception("Please set pinecone api key and environment in secrets file")
+ (
+ repo_name,
+ original_issue_number,
+ ) = self.git_provider._parse_issue_url(
+ self.issue_url.split('=')[-1]
+ )
+ issue_main = self.git_provider.repo_obj.get_issue(
+ original_issue_number
+ )
+ issue_main.create_comment(
+ "Please set pinecone api key and environment in secrets file"
+ )
+ raise Exception(
+ "Please set pinecone api key and environment in secrets file"
+ )
# check if index exists, and if repo is already indexed
run_from_scratch = False
@@ -69,7 +86,9 @@ class PRSimilarIssue:
upsert = True
else:
pinecone_index = pinecone.Index(index_name=index_name)
- res = pinecone_index.fetch([f"example_issue_{repo_name_for_index}"]).to_dict()
+ res = pinecone_index.fetch(
+ [f"example_issue_{repo_name_for_index}"]
+ ).to_dict()
if res["vectors"]:
upsert = False
@@ -79,7 +98,9 @@ class PRSimilarIssue:
get_logger().info('Getting issues...')
issues = list(repo_obj.get_issues(state='all'))
get_logger().info('Done')
- self._update_index_with_issues(issues, repo_name_for_index, upsert=upsert)
+ self._update_index_with_issues(
+ issues, repo_name_for_index, upsert=upsert
+ )
else: # update index if needed
pinecone_index = pinecone.Index(index_name=index_name)
issues_to_update = []
@@ -105,7 +126,9 @@ class PRSimilarIssue:
if issues_to_update:
get_logger().info(f'Updating index with {counter} new issues...')
- self._update_index_with_issues(issues_to_update, repo_name_for_index, upsert=True)
+ self._update_index_with_issues(
+ issues_to_update, repo_name_for_index, upsert=True
+ )
else:
get_logger().info('No new issues to update')
@@ -133,7 +156,12 @@ class PRSimilarIssue:
ingest = True
else:
self.table = self.db[index_name]
- res = self.table.search().limit(len(self.table)).where(f"id='example_issue_{repo_name_for_index}'").to_list()
+ res = (
+ self.table.search()
+ .limit(len(self.table))
+ .where(f"id='example_issue_{repo_name_for_index}'")
+ .to_list()
+ )
get_logger().info("result: ", res)
if res[0].get("vector"):
ingest = False
@@ -145,7 +173,9 @@ class PRSimilarIssue:
issues = list(repo_obj.get_issues(state='all'))
get_logger().info('Done')
- self._update_table_with_issues(issues, repo_name_for_index, ingest=ingest)
+ self._update_table_with_issues(
+ issues, repo_name_for_index, ingest=ingest
+ )
else: # update table if needed
issues_to_update = []
issues_paginated_list = repo_obj.get_issues(state='all')
@@ -156,7 +186,12 @@ class PRSimilarIssue:
issue_str, comments, number = self._process_issue(issue)
issue_key = f"issue_{number}"
issue_id = issue_key + "." + "issue"
- res = self.table.search().limit(len(self.table)).where(f"id='{issue_id}'").to_list()
+ res = (
+ self.table.search()
+ .limit(len(self.table))
+ .where(f"id='{issue_id}'")
+ .to_list()
+ )
is_new_issue = True
for r in res:
if r['metadata']['repo'] == repo_name_for_index:
@@ -170,14 +205,17 @@ class PRSimilarIssue:
if issues_to_update:
get_logger().info(f'Updating index with {counter} new issues...')
- self._update_table_with_issues(issues_to_update, repo_name_for_index, ingest=True)
+ self._update_table_with_issues(
+ issues_to_update, repo_name_for_index, ingest=True
+ )
else:
get_logger().info('No new issues to update')
-
async def run(self):
get_logger().info('Getting issue...')
- repo_name, original_issue_number = self.git_provider._parse_issue_url(self.issue_url.split('=')[-1])
+ repo_name, original_issue_number = self.git_provider._parse_issue_url(
+ self.issue_url.split('=')[-1]
+ )
issue_main = self.git_provider.repo_obj.get_issue(original_issue_number)
issue_str, comments, number = self._process_issue(issue_main)
openai.api_key = get_settings().openai.key
@@ -193,10 +231,12 @@ class PRSimilarIssue:
if get_settings().pr_similar_issue.vectordb == "pinecone":
pinecone_index = pinecone.Index(index_name=self.index_name)
- res = pinecone_index.query(embeds[0],
- top_k=5,
- filter={"repo": self.repo_name_for_index},
- include_metadata=True).to_dict()
+ res = pinecone_index.query(
+ embeds[0],
+ top_k=5,
+ filter={"repo": self.repo_name_for_index},
+ include_metadata=True,
+ ).to_dict()
for r in res['matches']:
# skip example issue
@@ -214,14 +254,20 @@ class PRSimilarIssue:
if issue_number not in relevant_issues_number_list:
relevant_issues_number_list.append(issue_number)
if 'comment' in r["id"]:
- relevant_comment_number_list.append(int(r["id"].split('.')[1].split('_')[-1]))
+ relevant_comment_number_list.append(
+ int(r["id"].split('.')[1].split('_')[-1])
+ )
else:
relevant_comment_number_list.append(-1)
score_list.append(str("{:.2f}".format(r['score'])))
get_logger().info('Done')
elif get_settings().pr_similar_issue.vectordb == "lancedb":
- res = self.table.search(embeds[0]).where(f"metadata.repo='{self.repo_name_for_index}'", prefilter=True).to_list()
+ res = (
+ self.table.search(embeds[0])
+ .where(f"metadata.repo='{self.repo_name_for_index}'", prefilter=True)
+ .to_list()
+ )
for r in res:
# skip example issue
@@ -240,10 +286,12 @@ class PRSimilarIssue:
relevant_issues_number_list.append(issue_number)
if 'comment' in r["id"]:
- relevant_comment_number_list.append(int(r["id"].split('.')[1].split('_')[-1]))
+ relevant_comment_number_list.append(
+ int(r["id"].split('.')[1].split('_')[-1])
+ )
else:
relevant_comment_number_list.append(-1)
- score_list.append(str("{:.2f}".format(1-r['_distance'])))
+ score_list.append(str("{:.2f}".format(1 - r['_distance'])))
get_logger().info('Done')
get_logger().info('Publishing response...')
@@ -254,8 +302,12 @@ class PRSimilarIssue:
title = issue.title
url = issue.html_url
if relevant_comment_number_list[i] != -1:
- url = list(issue.get_comments())[relevant_comment_number_list[i]].html_url
- similar_issues_str += f"{i + 1}. **[{title}]({url})** (score={score_list[i]})\n\n"
+ url = list(issue.get_comments())[
+ relevant_comment_number_list[i]
+ ].html_url
+ similar_issues_str += (
+ f"{i + 1}. **[{title}]({url})** (score={score_list[i]})\n\n"
+ )
if get_settings().config.publish_output:
response = issue_main.create_comment(similar_issues_str)
get_logger().info(similar_issues_str)
@@ -278,7 +330,7 @@ class PRSimilarIssue:
example_issue_record = Record(
id=f"example_issue_{repo_name_for_index}",
text="example_issue",
- metadata=Metadata(repo=repo_name_for_index)
+ metadata=Metadata(repo=repo_name_for_index),
)
corpus.append(example_issue_record)
@@ -298,15 +350,20 @@ class PRSimilarIssue:
issue_key = f"issue_{number}"
username = issue.user.login
created_at = str(issue.created_at)
- if len(issue_str) < 8000 or \
- self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL): # fast reject first
+ if len(issue_str) < 8000 or self.token_handler.count_tokens(
+ issue_str
+ ) < get_max_tokens(
+ MODEL
+ ): # fast reject first
issue_record = Record(
id=issue_key + "." + "issue",
text=issue_str,
- metadata=Metadata(repo=repo_name_for_index,
- username=username,
- created_at=created_at,
- level=IssueLevel.ISSUE)
+ metadata=Metadata(
+ repo=repo_name_for_index,
+ username=username,
+ created_at=created_at,
+ level=IssueLevel.ISSUE,
+ ),
)
corpus.append(issue_record)
if comments:
@@ -316,15 +373,20 @@ class PRSimilarIssue:
if num_words_comment < 10 or not isinstance(comment_body, str):
continue
- if len(comment_body) < 8000 or \
- self.token_handler.count_tokens(comment_body) < MAX_TOKENS[MODEL]:
+ if (
+ len(comment_body) < 8000
+ or self.token_handler.count_tokens(comment_body)
+ < MAX_TOKENS[MODEL]
+ ):
comment_record = Record(
id=issue_key + ".comment_" + str(j + 1),
text=comment_body,
- metadata=Metadata(repo=repo_name_for_index,
- username=username, # use issue username for all comments
- created_at=created_at,
- level=IssueLevel.COMMENT)
+ metadata=Metadata(
+ repo=repo_name_for_index,
+ username=username, # use issue username for all comments
+ created_at=created_at,
+ level=IssueLevel.COMMENT,
+ ),
)
corpus.append(comment_record)
df = pd.DataFrame(corpus.dict()["documents"])
@@ -355,7 +417,9 @@ class PRSimilarIssue:
environment = get_settings().pinecone.environment
if not upsert:
get_logger().info('Creating index from scratch...')
- ds.to_pinecone_index(self.index_name, api_key=api_key, environment=environment)
+ ds.to_pinecone_index(
+ self.index_name, api_key=api_key, environment=environment
+ )
time.sleep(15) # wait for pinecone to finalize indexing before querying
else:
get_logger().info('Upserting index...')
@@ -374,7 +438,7 @@ class PRSimilarIssue:
example_issue_record = Record(
id=f"example_issue_{repo_name_for_index}",
text="example_issue",
- metadata=Metadata(repo=repo_name_for_index)
+ metadata=Metadata(repo=repo_name_for_index),
)
corpus.append(example_issue_record)
@@ -394,15 +458,20 @@ class PRSimilarIssue:
issue_key = f"issue_{number}"
username = issue.user.login
created_at = str(issue.created_at)
- if len(issue_str) < 8000 or \
- self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL): # fast reject first
+ if len(issue_str) < 8000 or self.token_handler.count_tokens(
+ issue_str
+ ) < get_max_tokens(
+ MODEL
+ ): # fast reject first
issue_record = Record(
id=issue_key + "." + "issue",
text=issue_str,
- metadata=Metadata(repo=repo_name_for_index,
- username=username,
- created_at=created_at,
- level=IssueLevel.ISSUE)
+ metadata=Metadata(
+ repo=repo_name_for_index,
+ username=username,
+ created_at=created_at,
+ level=IssueLevel.ISSUE,
+ ),
)
corpus.append(issue_record)
if comments:
@@ -412,15 +481,20 @@ class PRSimilarIssue:
if num_words_comment < 10 or not isinstance(comment_body, str):
continue
- if len(comment_body) < 8000 or \
- self.token_handler.count_tokens(comment_body) < MAX_TOKENS[MODEL]:
+ if (
+ len(comment_body) < 8000
+ or self.token_handler.count_tokens(comment_body)
+ < MAX_TOKENS[MODEL]
+ ):
comment_record = Record(
id=issue_key + ".comment_" + str(j + 1),
text=comment_body,
- metadata=Metadata(repo=repo_name_for_index,
- username=username, # use issue username for all comments
- created_at=created_at,
- level=IssueLevel.COMMENT)
+ metadata=Metadata(
+ repo=repo_name_for_index,
+ username=username, # use issue username for all comments
+ created_at=created_at,
+ level=IssueLevel.COMMENT,
+ ),
)
corpus.append(comment_record)
df = pd.DataFrame(corpus.dict()["documents"])
@@ -446,7 +520,9 @@ class PRSimilarIssue:
if not ingest:
get_logger().info('Creating table from scratch...')
- self.table = self.db.create_table(self.index_name, data=df, mode="overwrite")
+ self.table = self.db.create_table(
+ self.index_name, data=df, mode="overwrite"
+ )
time.sleep(15)
else:
get_logger().info('Ingesting in Table...')
diff --git a/apps/utils/pr_agent/tools/pr_update_changelog.py b/apps/utils/pr_agent/tools/pr_update_changelog.py
index 56c9eca..f4a7b24 100644
--- a/apps/utils/pr_agent/tools/pr_update_changelog.py
+++ b/apps/utils/pr_agent/tools/pr_update_changelog.py
@@ -20,13 +20,20 @@ CHANGELOG_LINES = 50
class PRUpdateChangelog:
- def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
-
+ def __init__(
+ self,
+ pr_url: str,
+ cli_mode=False,
+ args=None,
+ ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
+ ):
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
- self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
+ self.commit_changelog = (
+ get_settings().pr_update_changelog.push_changelog_changes
+ )
self._get_changelog_file() # self.changelog_file_str
self.ai_handler = ai_handler()
@@ -47,15 +54,19 @@ class PRUpdateChangelog:
"extra_instructions": get_settings().pr_update_changelog.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
- self.token_handler = TokenHandler(self.git_provider.pr,
- self.vars,
- get_settings().pr_update_changelog_prompt.system,
- get_settings().pr_update_changelog_prompt.user)
+ self.token_handler = TokenHandler(
+ self.git_provider.pr,
+ self.vars,
+ get_settings().pr_update_changelog_prompt.system,
+ get_settings().pr_update_changelog_prompt.user,
+ )
async def run(self):
get_logger().info('Updating the changelog...')
- relevant_configs = {'pr_update_changelog': dict(get_settings().pr_update_changelog),
- 'config': dict(get_settings().config)}
+ relevant_configs = {
+ 'pr_update_changelog': dict(get_settings().pr_update_changelog),
+ 'config': dict(get_settings().config),
+ }
get_logger().debug("Relevant configs", artifacts=relevant_configs)
# currently only GitHub is supported for pushing changelog changes
@@ -74,13 +85,21 @@ class PRUpdateChangelog:
if get_settings().config.publish_output:
self.git_provider.publish_comment("准备变更日志更新中...", is_temporary=True)
- await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK)
+ await retry_with_fallback_models(
+ self._prepare_prediction, model_type=ModelType.WEAK
+ )
new_file_content, answer = self._prepare_changelog_update()
# Output the relevant configurations if enabled
- if get_settings().get('config', {}).get('output_relevant_configurations', False):
- answer += show_relevant_configurations(relevant_section='pr_update_changelog')
+ if (
+ get_settings()
+ .get('config', {})
+ .get('output_relevant_configurations', False)
+ ):
+ answer += show_relevant_configurations(
+ relevant_section='pr_update_changelog'
+ )
get_logger().debug(f"PR output", artifact=answer)
@@ -89,7 +108,9 @@ class PRUpdateChangelog:
if self.commit_changelog:
self._push_changelog_update(new_file_content, answer)
else:
- self.git_provider.publish_comment(f"**Changelog updates:** 🔄\n\n{answer}")
+ self.git_provider.publish_comment(
+ f"**Changelog updates:** 🔄\n\n{answer}"
+ )
async def _prepare_prediction(self, model: str):
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
@@ -106,10 +127,18 @@ class PRUpdateChangelog:
if get_settings().pr_update_changelog.add_pr_link:
variables["pr_link"] = self.git_provider.get_pr_url()
environment = Environment(undefined=StrictUndefined)
- system_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.system).render(variables)
- user_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.user).render(variables)
+ system_prompt = environment.from_string(
+ get_settings().pr_update_changelog_prompt.system
+ ).render(variables)
+ user_prompt = environment.from_string(
+ get_settings().pr_update_changelog_prompt.user
+ ).render(variables)
response, finish_reason = await self.ai_handler.chat_completion(
- model=model, system=system_prompt, user=user_prompt, temperature=get_settings().config.temperature)
+ model=model,
+ system=system_prompt,
+ user=user_prompt,
+ temperature=get_settings().config.temperature,
+ )
# post-process the response
response = response.strip()
@@ -134,8 +163,10 @@ class PRUpdateChangelog:
new_file_content = answer
if not self.commit_changelog:
- answer += "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" \
- "\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n"
+ answer += (
+ "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:"
+ "\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n"
+ )
return new_file_content, answer
@@ -163,8 +194,7 @@ class PRUpdateChangelog:
self.git_provider.publish_comment(f"**Changelog updates: 🔄**\n\n{answer}")
def _get_default_changelog(self):
- example_changelog = \
-"""
+ example_changelog = """
Example:
##
diff --git a/apps/utils/pr_agent/tools/ticket_pr_compliance_check.py b/apps/utils/pr_agent/tools/ticket_pr_compliance_check.py
index 387f428..efc7d3a 100644
--- a/apps/utils/pr_agent/tools/ticket_pr_compliance_check.py
+++ b/apps/utils/pr_agent/tools/ticket_pr_compliance_check.py
@@ -7,14 +7,15 @@ from utils.pr_agent.log import get_logger
# Compile the regex pattern once, outside the function
GITHUB_TICKET_PATTERN = re.compile(
- r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)'
+ r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)'
)
+
def find_jira_tickets(text):
# Regular expression patterns for JIRA tickets
patterns = [
r'\b[A-Z]{2,10}-\d{1,7}\b', # Standard JIRA ticket format (e.g., PROJ-123)
- r'(?:https?://[^\s/]+/browse/)?([A-Z]{2,10}-\d{1,7})\b' # JIRA URL or just the ticket
+ r'(?:https?://[^\s/]+/browse/)?([A-Z]{2,10}-\d{1,7})\b', # JIRA URL or just the ticket
]
tickets = set()
@@ -32,7 +33,9 @@ def find_jira_tickets(text):
return list(tickets)
-def extract_ticket_links_from_pr_description(pr_description, repo_path, base_url_html='https://github.com'):
+def extract_ticket_links_from_pr_description(
+ pr_description, repo_path, base_url_html='https://github.com'
+):
"""
Extract all ticket links from PR description
"""
@@ -46,19 +49,27 @@ def extract_ticket_links_from_pr_description(pr_description, repo_path, base_url
github_tickets.add(match[0])
elif match[1]: # Shorthand notation match: owner/repo#issue_number
owner, repo, issue_number = match[2], match[3], match[4]
- github_tickets.add(f'{base_url_html.strip("/")}/{owner}/{repo}/issues/{issue_number}')
+ github_tickets.add(
+ f'{base_url_html.strip("/")}/{owner}/{repo}/issues/{issue_number}'
+ )
else: # #123 format
issue_number = match[5][1:] # remove #
if issue_number.isdigit() and len(issue_number) < 5 and repo_path:
- github_tickets.add(f'{base_url_html.strip("/")}/{repo_path}/issues/{issue_number}')
+ github_tickets.add(
+ f'{base_url_html.strip("/")}/{repo_path}/issues/{issue_number}'
+ )
if len(github_tickets) > 3:
- get_logger().info(f"Too many tickets found in PR description: {len(github_tickets)}")
+ get_logger().info(
+ f"Too many tickets found in PR description: {len(github_tickets)}"
+ )
# Limit the number of tickets to 3
github_tickets = set(list(github_tickets)[:3])
except Exception as e:
- get_logger().error(f"Error extracting tickets error= {e}",
- artifact={"traceback": traceback.format_exc()})
+ get_logger().error(
+ f"Error extracting tickets error= {e}",
+ artifact={"traceback": traceback.format_exc()},
+ )
return list(github_tickets)
@@ -68,19 +79,26 @@ async def extract_tickets(git_provider):
try:
if isinstance(git_provider, GithubProvider):
user_description = git_provider.get_user_description()
- tickets = extract_ticket_links_from_pr_description(user_description, git_provider.repo, git_provider.base_url_html)
+ tickets = extract_ticket_links_from_pr_description(
+ user_description, git_provider.repo, git_provider.base_url_html
+ )
tickets_content = []
if tickets:
-
for ticket in tickets:
- repo_name, original_issue_number = git_provider._parse_issue_url(ticket)
+ repo_name, original_issue_number = git_provider._parse_issue_url(
+ ticket
+ )
try:
- issue_main = git_provider.repo_obj.get_issue(original_issue_number)
+ issue_main = git_provider.repo_obj.get_issue(
+ original_issue_number
+ )
except Exception as e:
- get_logger().error(f"Error getting main issue: {e}",
- artifact={"traceback": traceback.format_exc()})
+ get_logger().error(
+ f"Error getting main issue: {e}",
+ artifact={"traceback": traceback.format_exc()},
+ )
continue
issue_body_str = issue_main.body or ""
@@ -93,47 +111,66 @@ async def extract_tickets(git_provider):
sub_issues = git_provider.fetch_sub_issues(ticket)
for sub_issue_url in sub_issues:
try:
- sub_repo, sub_issue_number = git_provider._parse_issue_url(sub_issue_url)
- sub_issue = git_provider.repo_obj.get_issue(sub_issue_number)
+ (
+ sub_repo,
+ sub_issue_number,
+ ) = git_provider._parse_issue_url(sub_issue_url)
+ sub_issue = git_provider.repo_obj.get_issue(
+ sub_issue_number
+ )
sub_body = sub_issue.body or ""
if len(sub_body) > MAX_TICKET_CHARACTERS:
sub_body = sub_body[:MAX_TICKET_CHARACTERS] + "..."
- sub_issues_content.append({
- 'ticket_url': sub_issue_url,
- 'title': sub_issue.title,
- 'body': sub_body
- })
+ sub_issues_content.append(
+ {
+ 'ticket_url': sub_issue_url,
+ 'title': sub_issue.title,
+ 'body': sub_body,
+ }
+ )
except Exception as e:
- get_logger().warning(f"Failed to fetch sub-issue content for {sub_issue_url}: {e}")
+ get_logger().warning(
+ f"Failed to fetch sub-issue content for {sub_issue_url}: {e}"
+ )
except Exception as e:
- get_logger().warning(f"Failed to fetch sub-issues for {ticket}: {e}")
+ get_logger().warning(
+ f"Failed to fetch sub-issues for {ticket}: {e}"
+ )
# Extract labels
labels = []
try:
for label in issue_main.labels:
- labels.append(label.name if hasattr(label, 'name') else label)
+ labels.append(
+ label.name if hasattr(label, 'name') else label
+ )
except Exception as e:
- get_logger().error(f"Error extracting labels error= {e}",
- artifact={"traceback": traceback.format_exc()})
+ get_logger().error(
+ f"Error extracting labels error= {e}",
+ artifact={"traceback": traceback.format_exc()},
+ )
- tickets_content.append({
- 'ticket_id': issue_main.number,
- 'ticket_url': ticket,
- 'title': issue_main.title,
- 'body': issue_body_str,
- 'labels': ", ".join(labels),
- 'sub_issues': sub_issues_content # Store sub-issues content
- })
+ tickets_content.append(
+ {
+ 'ticket_id': issue_main.number,
+ 'ticket_url': ticket,
+ 'title': issue_main.title,
+ 'body': issue_body_str,
+ 'labels': ", ".join(labels),
+ 'sub_issues': sub_issues_content, # Store sub-issues content
+ }
+ )
return tickets_content
except Exception as e:
- get_logger().error(f"Error extracting tickets error= {e}",
- artifact={"traceback": traceback.format_exc()})
+ get_logger().error(
+ f"Error extracting tickets error= {e}",
+ artifact={"traceback": traceback.format_exc()},
+ )
async def extract_and_cache_pr_tickets(git_provider, vars):
@@ -154,8 +191,10 @@ async def extract_and_cache_pr_tickets(git_provider, vars):
related_tickets.append(ticket)
- get_logger().info("Extracted tickets and sub-issues from PR description",
- artifact={"tickets": related_tickets})
+ get_logger().info(
+ "Extracted tickets and sub-issues from PR description",
+ artifact={"tickets": related_tickets},
+ )
vars['related_tickets'] = related_tickets
get_settings().set('related_tickets', related_tickets)
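
As context for the ticket-extraction changes above: GITHUB_TICKET_PATTERN matches three shapes of ticket reference (full issue URL, owner/repo#number shorthand, bare #number), and extract_ticket_links_from_pr_description normalizes them all to full issue URLs. An illustrative, standalone use of the same pattern (the sample text and repo path are made up):

# Standalone illustration of the three formats GITHUB_TICKET_PATTERN captures.
import re

pattern = re.compile(
    r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)'
)

text = "Fixes https://github.com/org/repo/issues/12, relates to org/repo#34 and #56"
for m in pattern.findall(text):
    if m[0]:    # full issue URL
        print(m[0])
    elif m[1]:  # owner/repo#number shorthand -> rebuild the URL
        print(f"https://github.com/{m[2]}/{m[3]}/issues/{m[4]}")
    else:       # bare #number -> resolved against the current repo path
        print(f"https://github.com/org/repo/issues/{m[5][1:]}")
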
diff --git a/config.ini b/config.ini
new file mode 100644
index 0000000..4c277de
--- /dev/null
+++ b/config.ini
@@ -0,0 +1,13 @@
+[BASE]
+; 是否开启debug模式 0或1
+DEBUG = 0
+
+[DATABASE]
+; 默认采用sqlite,线上需替换为pg
+DEFAULT = pg
+; postgres配置
+DB_NAME = pr_manager
+DB_USER = admin
+DB_PASSWORD = admin123456
+DB_HOST = 110.40.30.95
+DB_PORT = 5432
diff --git a/pr_manager/settings.py b/pr_manager/settings.py
index d2a19e0..a4423ff 100644
--- a/pr_manager/settings.py
+++ b/pr_manager/settings.py
@@ -12,11 +12,18 @@ https://docs.djangoproject.com/en/5.1/ref/settings/
import os
import sys
+import configparser
from pathlib import Path
# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent
+CONFIG_NAME = BASE_DIR / "config.ini"
+
+# 加载配置文件: 开发可加载config.local.ini
+_config = configparser.ConfigParser()
+_config.read(CONFIG_NAME, encoding="utf-8")
+
sys.path.insert(0, os.path.join(BASE_DIR, "apps"))
sys.path.insert(1, os.path.join(BASE_DIR, "apps/utils"))
@@ -27,7 +34,7 @@ sys.path.insert(1, os.path.join(BASE_DIR, "apps/utils"))
SECRET_KEY = "django-insecure-$r6lfcq8rev&&=chw259o$0o7t-!!%clc2ahs3xg$^z+gkms76"
# SECURITY WARNING: don't run with debug turned on in production!
-DEBUG = False
+DEBUG = bool(int(_config["BASE"].get("DEBUG", "1")))
ALLOWED_HOSTS = ["*"]
@@ -44,7 +51,7 @@ INSTALLED_APPS = [
"django.contrib.messages",
"django.contrib.staticfiles",
"public",
- "pr"
+ "pr",
]
# 配置安全秘钥
@@ -68,8 +75,7 @@ ROOT_URLCONF = "pr_manager.urls"
TEMPLATES = [
{
"BACKEND": "django.template.backends.django.DjangoTemplates",
- "DIRS": [BASE_DIR / 'templates']
- ,
+ "DIRS": [BASE_DIR / 'templates'],
"APP_DIRS": True,
"OPTIONS": {
"context_processors": [
@@ -89,12 +95,22 @@ WSGI_APPLICATION = "pr_manager.wsgi.application"
# https://docs.djangoproject.com/en/5.1/ref/settings/#databases
DATABASES = {
- "default": {
+ "pg": {
+ "ENGINE": "django.db.backends.postgresql",
+ "NAME": _config["DATABASE"].get("DB_NAME", "chat_ai_v2"),
+ "USER": _config["DATABASE"].get("DB_USER", "admin"),
+ "PASSWORD": _config["DATABASE"].get("DB_PASSWORD", "admin123456"),
+ "HOST": _config["DATABASE"].get("DB_HOST", "124.222.222.101"),
+ "PORT": int(_config["DATABASE"].get("DB_PORT", "5432")),
+ },
+ "sqlite": {
"ENGINE": "django.db.backends.sqlite3",
"NAME": BASE_DIR / "db.sqlite3",
- }
+ },
}
+DATABASES["default"] = DATABASES[_config["DATABASE"].get("DEFAULT", "sqlite")]
+
# Password validation
# https://docs.djangoproject.com/en/5.1/ref/settings/#auth-password-validators
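
To make the new database wiring at the end of settings.py easier to follow: the diff defines both a "pg" and a "sqlite" entry and then aliases DATABASES["default"] to whichever key config.ini names under [DATABASE] DEFAULT, falling back to sqlite. A minimal standalone sketch of that selection, with illustrative values:

# Standalone sketch of the config-driven default-database selection above.
import configparser

config = configparser.ConfigParser()
config.read_string("""
[DATABASE]
DEFAULT = pg
DB_NAME = pr_manager
""")

DATABASES = {
    "pg": {
        "ENGINE": "django.db.backends.postgresql",
        "NAME": config["DATABASE"].get("DB_NAME", "pr_manager"),
    },
    "sqlite": {"ENGINE": "django.db.backends.sqlite3", "NAME": "db.sqlite3"},
}
# Falls back to sqlite when DEFAULT is missing from config.ini.
DATABASES["default"] = DATABASES[config["DATABASE"].get("DEFAULT", "sqlite")]

print(DATABASES["default"]["ENGINE"])  # django.db.backends.postgresql
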