first commit

This commit is contained in:
张建平 2025-02-25 14:29:18 +08:00
commit 578f5d9a3d
118 changed files with 20358 additions and 0 deletions

15
.gitignore vendored Normal file
View File

@ -0,0 +1,15 @@
.idea/
.lsp/
.vscode/
.env
venv/
pr_agent/settings/.secrets.toml
__pycache__
dist/
*.egg-info/
build/
.DS_Store
docs/.cache/
.qodo
db.sqlite3
#pr_agent/

21
Dockerfile Normal file
View File

@ -0,0 +1,21 @@
# Runtime image for the PR manager service.
FROM python:3.12-slim
# key=value form; the space-separated ENV form is legacy syntax.
ENV PYTHONUNBUFFERED=1
ENV TZ=Asia/Shanghai
WORKDIR /app
COPY . /app
# Point apt at the Aliyun mirror. Debian 12+ based images keep their sources in
# /etc/apt/sources.list.d/debian.sources (deb822 format) and ship no
# /etc/apt/sources.list, so rewrite whichever of the two files exists instead
# of failing the build on a missing file.
RUN for f in /etc/apt/sources.list /etc/apt/sources.list.d/debian.sources; do \
        if [ -f "$f" ]; then sed -i 's/deb.debian.org/mirrors.aliyun.com/g' "$f"; fi; \
    done
# System tools, container time zone, and pipenv (installed from the Tsinghua PyPI mirror).
RUN apt-get update \
    && apt-get install -y procps net-tools apt-utils \
    && ln -snf /usr/share/zoneinfo/${TZ} /etc/localtime && echo ${TZ} > /etc/timezone \
    && pip install pipenv -i https://pypi.tuna.tsinghua.edu.cn/simple/
RUN pipenv sync && pipenv install --dev
RUN chmod +x /app/start.sh
CMD ["sh", "start.sh"]

28
Pipfile Normal file
View File

@ -0,0 +1,28 @@
[[source]]
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
verify_ssl = true
name = "pip_conf_index_global"
[packages]
django = "*"
simplepro = "*"
django-import-export = "*"
litellm = "*"
tenacity = "*"
html2text = "*"
starlette-context = "*"
dynaconf = "*"
loguru = "*"
atlassian-python-api = "*"
boto3 = "*"
gitpython = "*"
pygithub = "*"
python-gitlab = "*"
retry = "*"
fastapi = "*"
[dev-packages]
[requires]
python_version = "3.12"
python_full_version = "3.12.1"

2153
Pipfile.lock generated Normal file

File diff suppressed because it is too large Load Diff

0
apps/__init__.py Normal file
View File

0
apps/pr/__init__.py Normal file
View File

40
apps/pr/admin.py Normal file
View File

@ -0,0 +1,40 @@
from django.contrib import admin
from simpleui.admin import AjaxAdmin
from pr import models
@admin.register(models.AIConfig)
class AIConfigAdmin(AjaxAdmin):
    """Admin configuration for the ``AIConfig`` model."""

    # Columns shown in the change list.
    # NOTE(review): api_key is listed here in clear text — confirm that is intended.
    list_display = ["api_base", "api_key", "llm_model"]
    # HTML snippet; presumably rendered above the list by simpleui — TODO confirm.
    top_html = ' <el-alert title="可配置多个AI模型厂商!" type="success"></el-alert>'

    def save_model(self, request, obj, form, change):
        # Stamp the object with the username of the admin user performing the
        # save; this runs on every save, not only on creation.
        obj.create_by = request.user.username
        return super().save_model(request, obj, form, change)
@admin.register(models.GitConfig)
class GitConfigAdmin(AjaxAdmin):
    """Admin configuration for the ``GitConfig`` model."""

    # Columns shown in the change list.
    # NOTE(review): access_token is listed here in clear text — confirm that is intended.
    list_display = ["git_name", "git_type", "git_url", "access_token"]
    # HTML snippet; presumably rendered above the list by simpleui — TODO confirm.
    top_html = '<el-alert title="可配置多个Git服务上!" type="success"></el-alert>'

    def save_model(self, request, obj, form, change):
        # Stamp the object with the username of the admin user performing the
        # save; this runs on every save, not only on creation.
        obj.create_by = request.user.username
        return super().save_model(request, obj, form, change)
@admin.register(models.ProjectConfig)
class ProjectConfigAdmin(AjaxAdmin):
    """Admin configuration for the ``ProjectConfig`` model."""

    # Columns shown in the change list.
    # NOTE(review): project_secret is listed here in clear text — confirm that is intended.
    list_display = ["project_id", "project_name", "project_secret", "commands", "is_enable"]
    # HTML snippet; presumably rendered above the list by simpleui — TODO confirm.
    top_html = '<el-alert title="可配置多个项目!" type="success"></el-alert>'

    def save_model(self, request, obj, form, change):
        # Stamp the object with the username of the admin user performing the
        # save; this runs on every save, not only on creation.
        obj.create_by = request.user.username
        return super().save_model(request, obj, form, change)

7
apps/pr/apps.py Normal file
View File

@ -0,0 +1,7 @@
from django.apps import AppConfig
class PrConfig(AppConfig):
    """Django application configuration for the ``pr`` app."""

    default_auto_field = "django.db.models.BigAutoField"
    name = "pr"
    # Human-readable app name (Chinese: "PR management configuration").
    verbose_name = "PR管理配置"

View File

View File

View File

@ -0,0 +1,19 @@
from django.core.management.base import BaseCommand
from pr import models
class Command(BaseCommand):
    """Management command that seeds the initial ``AIConfig`` row.

    The endpoint, key and model can be overridden with the environment
    variables ``PR_AI_API_BASE``, ``PR_AI_API_KEY`` and ``PR_AI_LLM_MODEL``;
    when they are unset the previously hard-coded values are used, so the
    default behavior is unchanged.
    """

    help = "数据初始化"

    def handle(self, *args, **options):
        """Create the default ``AIConfig`` row if it does not exist yet.

        The row is looked up by all three values (api_base, api_key,
        llm_model); a new row is created whenever any of them differs from
        what is stored.
        """
        import os

        # SECURITY NOTE(review): the fallback below is a real-looking API key
        # committed to the repository. It is kept only for backward
        # compatibility; rotate it and supply the key via PR_AI_API_KEY.
        api_base = os.environ.get("PR_AI_API_BASE", "http://110.40.24.85:3000/v1")
        api_key = os.environ.get(
            "PR_AI_API_KEY", "sk-YLeQEboTsCEzfbmhbnytWRPyuC8Swe7OsBRKH30X26Jf1fsm"
        )
        llm_model = os.environ.get("PR_AI_LLM_MODEL", "o3-mini")

        ai_config, created = models.AIConfig.objects.get_or_create(
            api_base=api_base,
            api_key=api_key,
            llm_model=llm_model,
        )
        # Write through the command's stdout wrapper (instead of print) so the
        # output can be captured or redirected by call_command().
        if created:
            self.stdout.write("初始化AI配置已创建")
        else:
            self.stdout.write("初始化AI配置已存在")

View File

@ -0,0 +1,260 @@
# Generated by Django 5.1.6 on 2025-02-25 13:55
import django.db.models.deletion
import simplepro.components.fields
import uuid
from django.db import migrations, models
class Migration(migrations.Migration):
    # Initial migration for the "pr" app: creates the AIConfig, GitConfig and
    # ProjectConfig tables. Each table carries the same bookkeeping columns
    # (id, uid, create_at, update_at, delete_at, create_by, detail) followed by
    # its own fields.

    initial = True

    dependencies = []

    operations = [
        # AI model provider settings.
        migrations.CreateModel(
            name="AIConfig",
            fields=[
                ("id", models.BigAutoField(primary_key=True, serialize=False)),
                (
                    "uid",
                    models.UUIDField(
                        db_index=True,
                        default=uuid.uuid4,
                        editable=False,
                        verbose_name="UUID",
                    ),
                ),
                (
                    "create_at",
                    simplepro.components.fields.DateTimeField(
                        auto_now_add=True, db_index=True, verbose_name="创建时间"
                    ),
                ),
                (
                    "update_at",
                    simplepro.components.fields.DateTimeField(
                        auto_now=True, verbose_name="更新时间"
                    ),
                ),
                (
                    "delete_at",
                    simplepro.components.fields.DateTimeField(
                        blank=True, null=True, verbose_name="删除时间"
                    ),
                ),
                (
                    "create_by",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=32, null=True, verbose_name="创建人"
                    ),
                ),
                (
                    "detail",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=200, null=True, verbose_name="备注信息"
                    ),
                ),
                (
                    "api_base",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=128, null=True, verbose_name="API(代理)地址"
                    ),
                ),
                (
                    "api_key",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=128, null=True, verbose_name="API密钥"
                    ),
                ),
                (
                    "llm_model",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=16, null=True, verbose_name="LLM模型"
                    ),
                ),
            ],
            options={
                "verbose_name": "AI模型配置",
                "verbose_name_plural": "AI模型配置",
            },
        ),
        # Git server settings; each row optionally references an AIConfig.
        migrations.CreateModel(
            name="GitConfig",
            fields=[
                ("id", models.BigAutoField(primary_key=True, serialize=False)),
                (
                    "uid",
                    models.UUIDField(
                        db_index=True,
                        default=uuid.uuid4,
                        editable=False,
                        verbose_name="UUID",
                    ),
                ),
                (
                    "create_at",
                    simplepro.components.fields.DateTimeField(
                        auto_now_add=True, db_index=True, verbose_name="创建时间"
                    ),
                ),
                (
                    "update_at",
                    simplepro.components.fields.DateTimeField(
                        auto_now=True, verbose_name="更新时间"
                    ),
                ),
                (
                    "delete_at",
                    simplepro.components.fields.DateTimeField(
                        blank=True, null=True, verbose_name="删除时间"
                    ),
                ),
                (
                    "create_by",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=32, null=True, verbose_name="创建人"
                    ),
                ),
                (
                    "detail",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=200, null=True, verbose_name="备注信息"
                    ),
                ),
                (
                    "git_name",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=16, null=True, verbose_name="Git名称"
                    ),
                ),
                (
                    "git_type",
                    simplepro.components.fields.RadioField(
                        choices=[
                            ("gitlab", "gitlab"),
                            ("github", "github"),
                            ("gitea", "gitea"),
                        ],
                        default="gitlab",
                        verbose_name="Git类型",
                    ),
                ),
                (
                    "git_url",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=128, null=True, verbose_name="Git地址"
                    ),
                ),
                (
                    "access_token",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=128, null=True, verbose_name="访问密钥"
                    ),
                ),
                (
                    "pr_ai",
                    simplepro.components.fields.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        to="pr.aiconfig",
                        verbose_name="AI模型",
                    ),
                ),
            ],
            options={
                "verbose_name": "Git服务配置",
                "verbose_name_plural": "Git服务配置",
            },
        ),
        # Per-project settings; each row optionally references a GitConfig.
        migrations.CreateModel(
            name="ProjectConfig",
            fields=[
                ("id", models.BigAutoField(primary_key=True, serialize=False)),
                (
                    "uid",
                    models.UUIDField(
                        db_index=True,
                        default=uuid.uuid4,
                        editable=False,
                        verbose_name="UUID",
                    ),
                ),
                (
                    "create_at",
                    simplepro.components.fields.DateTimeField(
                        auto_now_add=True, db_index=True, verbose_name="创建时间"
                    ),
                ),
                (
                    "update_at",
                    simplepro.components.fields.DateTimeField(
                        auto_now=True, verbose_name="更新时间"
                    ),
                ),
                (
                    "delete_at",
                    simplepro.components.fields.DateTimeField(
                        blank=True, null=True, verbose_name="删除时间"
                    ),
                ),
                (
                    "create_by",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=32, null=True, verbose_name="创建人"
                    ),
                ),
                (
                    "detail",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=200, null=True, verbose_name="备注信息"
                    ),
                ),
                (
                    "project_id",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=8, null=True, verbose_name="项目ID"
                    ),
                ),
                (
                    "project_name",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=16, null=True, verbose_name="项目名称"
                    ),
                ),
                (
                    "project_secret",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=128, null=True, verbose_name="项目密钥"
                    ),
                ),
                (
                    "commands",
                    simplepro.components.fields.CheckboxField(
                        default=["/review"], max_length=256, verbose_name="默认命令"
                    ),
                ),
                (
                    "is_enable",
                    simplepro.components.fields.SwitchField(
                        default=True, verbose_name="是否启用"
                    ),
                ),
                (
                    "git_config",
                    simplepro.components.fields.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        to="pr.gitconfig",
                        verbose_name="Git配置",
                    ),
                ),
            ],
            options={
                "verbose_name": "项目配置",
                "verbose_name_plural": "项目配置",
            },
        ),
    ]

View File

94
apps/pr/models.py Normal file
View File

@ -0,0 +1,94 @@
from django.db import models
from simplepro.components import fields
from public.models import BaseModel
from utils import constant
class AIConfig(BaseModel):
    """
    AI model configuration table.

    Stores the endpoint, credential and model name of one LLM provider.
    """

    # Base URL of the (proxy) API endpoint.
    api_base = fields.CharField(
        null=True, blank=True, max_length=128, verbose_name="API(代理)地址"
    )
    # API key for the endpoint; stored as plain text in this column.
    api_key = fields.CharField(
        null=True, blank=True, max_length=128, verbose_name="API密钥"
    )
    # Model name; limited to 16 characters.
    llm_model = fields.CharField(
        null=True, blank=True, max_length=16, verbose_name="LLM模型"
    )

    class Meta:
        verbose_name = "AI模型配置"
        verbose_name_plural = "AI模型配置"
class GitConfig(BaseModel):
    """
    Git service configuration table.

    Describes one Git server and the AI configuration used for it.
    """

    # AI configuration for this Git server; set to NULL if that row is deleted.
    pr_ai = fields.ForeignKey(
        AIConfig,
        null=True,
        blank=True,
        on_delete=models.SET_NULL,
        verbose_name="AI模型",
    )
    git_name = fields.CharField(
        null=True, blank=True, max_length=16, verbose_name="Git名称"
    )
    # One of the provider names listed in constant.GIT_TYPE.
    git_type = fields.RadioField(
        choices=constant.GIT_TYPE,
        default="gitlab",
        verbose_name="Git类型"
    )
    git_url = fields.CharField(
        null=True, blank=True, max_length=128, verbose_name="Git地址"
    )
    # Access token for the Git server; stored as plain text in this column.
    access_token = fields.CharField(
        null=True, blank=True, max_length=128, verbose_name="访问密钥"
    )

    class Meta:
        verbose_name = "Git服务配置"
        verbose_name_plural = "Git服务配置"
class ProjectConfig(BaseModel):
    """
    Project configuration table.

    Links a project on a Git server to its secret and the commands to run.
    """

    # Git server this project lives on; set to NULL if that row is deleted.
    git_config = fields.ForeignKey(
        GitConfig,
        null=True,
        blank=True,
        on_delete=models.SET_NULL,
        verbose_name="Git配置",
    )
    # Project identifier; stored as text, at most 8 characters.
    project_id = fields.CharField(
        null=True, blank=True, max_length=8, verbose_name="项目ID"
    )
    project_name = fields.CharField(
        null=True, blank=True, max_length=16, verbose_name="项目名称"
    )
    # Project secret; stored as plain text in this column.
    project_secret = fields.CharField(
        null=True, blank=True, max_length=128, verbose_name="项目密钥"
    )
    # Commands selected for this project, chosen from constant.DEFAULT_COMMANDS.
    commands = fields.CheckboxField(
        choices=constant.DEFAULT_COMMANDS,
        default=["/review"],
        max_length=256,
        verbose_name="默认命令",
    )
    # Enable/disable switch for the project.
    is_enable = fields.SwitchField(
        default=True,
        verbose_name="是否启用"
    )

    class Meta:
        verbose_name = "项目配置"
        verbose_name_plural = "项目配置"

3
apps/pr/tests.py Normal file
View File

@ -0,0 +1,3 @@
from django.test import TestCase
# Create your tests here.

24
apps/pr/urls.py Normal file
View File

@ -0,0 +1,24 @@
"""
URL configuration for pr_manager project.
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/5.1/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.urls import path
from pr.views import WebHookView
# Routes of the "pr" app: "webhook/" is handled by WebHookView.
urlpatterns = [
    path("webhook/", WebHookView.as_view()),
]

113
apps/pr/views.py Normal file
View File

@ -0,0 +1,113 @@
from pr import models
from django.views import View
from django.http import JsonResponse
from utils.pr_agent import cli
from utils.pr_agent.config_loader import get_settings
from utils import constant
def load_project_config(
    git_url,
    access_token,
    project_secret,
    openai_api_base,
    openai_key,
    llm_model
):
    """
    Bundle the per-project settings into a single dictionary.

    :param git_url: Git server address (stored under ``gitlab_url``)
    :param access_token: user access token
    :param project_secret: project secret (stored under ``secret``)
    :param openai_api_base: OpenAI API base URL
    :param openai_key: OpenAI API key
    :param llm_model: LLM model name
    :return: dict with the keys ``gitlab_url``, ``access_token``, ``secret``,
        ``openai_api_base``, ``openai_key`` and ``llm_model``
    """
    project_settings = dict(
        gitlab_url=git_url,
        access_token=access_token,
        secret=project_secret,
        openai_api_base=openai_api_base,
        openai_key=openai_key,
        llm_model=llm_model,
    )
    return project_settings
class WebHookView(View):
    """Receive Git webhook events and start PR-Agent commands for merge requests.

    NOTE(review): if Django's CSRF middleware is enabled for this project,
    POSTs without a CSRF token are rejected before reaching this view —
    confirm the view is exempted where it is routed.
    NOTE(review): ``ProjectConfig.is_enable`` is not consulted here — confirm
    whether disabled projects should be ignored.
    """

    def post(self, request):
        """Handle one webhook delivery.

        Flow: parse the JSON body, look up the project configuration by
        project id, verify the ``X-Gitlab-Token`` header against the project
        secret, and for an opened merge request start one thread per
        configured command.

        :param request: the incoming Django request; the event is read from
            its JSON body
        :return: ``JsonResponse`` — 400 for a malformed payload or incomplete
            configuration, 404 for an unknown project, 403 for a bad token,
            200 otherwise
        """
        import hmac
        import json
        import threading

        # Webhook events arrive as a JSON body; request.POST only holds
        # form-encoded fields, so nested lookups on it cannot work.
        try:
            data = json.loads(request.body)
        except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
            data = None
        if not data or not isinstance(data, dict):
            return JsonResponse(status=400, data={"error": "Invalid JSON"})

        project = data.get('project')
        nested_id = project.get('id') if isinstance(project, dict) else None
        project_id = nested_id or data.get('project_id')
        if not project_id:
            return JsonResponse(status=400, data={"error": "Missing project ID"})

        project_config = models.ProjectConfig.objects.filter(project_id=project_id).first()
        if project_config is None:
            return JsonResponse(status=404, data={"error": "Project not configured"})

        # Both relations are nullable (SET_NULL), so guard before dereferencing.
        git_config = project_config.git_config
        ai_config = git_config.pr_ai if git_config is not None else None
        if git_config is None or ai_config is None:
            return JsonResponse(status=400, data={"error": "Project configuration incomplete"})

        # AI model configuration
        api_base = ai_config.api_base
        api_key = ai_config.api_key
        model = ai_config.llm_model

        # Git server configuration
        git_url = git_config.git_url
        git_type = git_config.git_type
        access_token = git_config.access_token
        project_secret = project_config.project_secret
        project_commands = project_config.commands or []

        config = load_project_config(
            git_url=git_url,
            access_token=access_token,
            project_secret=project_secret,
            openai_api_base=api_base,
            openai_key=api_key,
            llm_model=model
        )

        # Verify the shared secret. A request WITHOUT the header must not
        # bypass the check when a secret is configured, so compare whenever
        # either side is non-empty. compare_digest avoids timing leaks.
        expected_token = (config["secret"] or "").strip()
        token = (request.headers.get('X-Gitlab-Token') or "").strip()
        if expected_token or token:
            if not hmac.compare_digest(token.encode("utf-8"), expected_token.encode("utf-8")):
                return JsonResponse(status=403, data={"error": "Invalid token"})

        # Handle merge request events
        if data.get('object_kind') == 'merge_request':
            merge_request = data.get('object_attributes') or {}
            if merge_request.get('state') == 'opened':
                # Details of the merge request
                mr_url = merge_request.get('url')
                mr_action = merge_request.get('action')

                # NOTE(review): these settings are set through get_settings()
                # on every request and then read by the worker threads;
                # whether that object is shared between concurrent requests
                # is not visible here — confirm.
                get_settings().set("config.git_provider", git_type)
                get_settings().set("gitlab.url", git_url)
                get_settings().set("gitlab.personal_access_token", access_token)
                get_settings().set("openai.api_base", api_base)
                get_settings().set("openai.key", api_key)
                get_settings().set("llm.model", model)

                if mr_action == "update":
                    old_rev = merge_request.get("oldrev")
                    new_rev = merge_request.get("newrev")
                    # Equal revisions (including both absent) mean no new commits.
                    if old_rev == new_rev:
                        return JsonResponse(status=200, data={"status": "ignored (no code change)"})

                # Run each allowed command in its own thread so the webhook
                # can answer immediately.
                for cmd in project_commands:
                    if cmd not in constant.DEFAULT_COMMANDS:
                        continue
                    threading.Thread(target=cli.run_command, args=(mr_url, cmd)).start()

                return JsonResponse(status=200, data={"status": "review started"})
            return JsonResponse(status=400, data={"error": "Merge request URL not found or action not open"})
        return JsonResponse(status=200, data={"status": "ignored"})

0
apps/public/__init__.py Normal file
View File

3
apps/public/admin.py Normal file
View File

@ -0,0 +1,3 @@
from django.contrib import admin
# Register your models here.

6
apps/public/apps.py Normal file
View File

@ -0,0 +1,6 @@
from django.apps import AppConfig
class PublicConfig(AppConfig):
    """Django application configuration for the ``public`` app."""

    default_auto_field = "django.db.models.BigAutoField"
    name = "public"

View File

64
apps/public/models.py Normal file
View File

@ -0,0 +1,64 @@
import uuid
from datetime import datetime
from django.db import models
from simplepro.components import fields
class BaseQuerySet(models.QuerySet):
    """QuerySet that supports bulk soft deletion."""

    def set_delete(self):
        """Soft-delete every row in this queryset by stamping ``delete_at``.

        :return: whatever ``QuerySet.update`` returns for the update
        """
        from django.utils import timezone

        # timezone.now() is aware when USE_TZ is on and naive otherwise, so it
        # never stores a naive datetime into a time-zone-aware column.
        return super().update(delete_at=timezone.now())
class BaseManager(models.Manager):
    """Manager that hides soft-deleted rows unless ``alive_only=False``."""

    def __init__(self, *args, **kwargs):
        # alive_only is this manager's own option; remove it before handing
        # the remaining arguments to Django's Manager.
        self.alive_only = kwargs.pop("alive_only", True)
        super().__init__(*args, **kwargs)

    def get_queryset(self):
        """Return a ``BaseQuerySet``, filtered to non-deleted rows when ``alive_only``."""
        # Pass the manager's database alias through so db_manager()/using()
        # routing is respected by the custom queryset.
        queryset = BaseQuerySet(self.model, using=self._db)
        if self.alive_only:
            # When True, callers (e.g. the admin) only see rows whose
            # delete_at is still empty.
            return queryset.filter(delete_at__isnull=True)
        return queryset
class BaseModel(models.Model):
    """
    Abstract base model shared by the project's tables.

    Provides the primary key, a UUID, creation/update/deletion timestamps,
    the creator's name and a free-text note, plus soft deletion through
    ``delete_at``.
    """

    id = models.BigAutoField(primary_key=True)
    uid = models.UUIDField(
        default=uuid.uuid4, editable=False, db_index=True, verbose_name="UUID"
    )
    create_at = fields.DateTimeField(
        auto_now_add=True, db_index=True, verbose_name="创建时间"
    )
    update_at = fields.DateTimeField(auto_now=True, verbose_name="更新时间")
    # Empty while the row is alive; set to the deletion time on soft delete.
    delete_at = fields.DateTimeField(null=True, blank=True, verbose_name="删除时间")
    create_by = fields.CharField(
        null=True, blank=True, max_length=32, verbose_name="创建人"
    )
    detail = fields.CharField(
        null=True,
        blank=True,
        max_length=200,
        show_word_limit=True,
        prefix_icon="el-icon-edit",
        verbose_name="备注信息",
        placeholder="请输入备注信息(可为空)",
    )

    objects = BaseManager()  # default manager: only rows that are not soft-deleted
    all_objects = BaseManager(alive_only=False)  # all rows, including soft-deleted ones

    class Meta:
        abstract = True
        ordering = ["-create_at"]

    def set_delete(self):
        """Soft-delete this row: stamp ``delete_at`` and save."""
        from django.utils import timezone

        # timezone.now() is aware when USE_TZ is on and naive otherwise, so it
        # never stores a naive datetime into a time-zone-aware column.
        self.delete_at = timezone.now()
        self.save()

3
apps/public/tests.py Normal file
View File

@ -0,0 +1,3 @@
from django.test import TestCase
# Create your tests here.

3
apps/public/views.py Normal file
View File

@ -0,0 +1,3 @@
from django.shortcuts import render
# Create your views here.

0
apps/utils/__init__.py Normal file
View File

11
apps/utils/constant.py Normal file
View File

@ -0,0 +1,11 @@
# (value, label) choices for the Git provider type.
GIT_TYPE = (
    ("gitlab", "gitlab"),
    ("github", "github"),
    ("gitea", "gitea")
)
# Commands a project may be configured to run.
DEFAULT_COMMANDS = [
    "/review",
    "/describe",
    "/improve_code"
]

View File

View File

View File

@ -0,0 +1,93 @@
import shlex
from functools import partial
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from utils.pr_agent.algo.cli_args import CliArgs
from utils.pr_agent.algo.utils import update_settings_from_args
from utils.pr_agent.git_providers.utils import apply_repo_settings
from utils.pr_agent.log import get_logger
from utils.pr_agent.tools.pr_add_docs import PRAddDocs
from utils.pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from utils.pr_agent.tools.pr_config import PRConfig
from utils.pr_agent.tools.pr_description import PRDescription
from utils.pr_agent.tools.pr_generate_labels import PRGenerateLabels
from utils.pr_agent.tools.pr_help_message import PRHelpMessage
from utils.pr_agent.tools.pr_line_questions import PR_LineQuestions
from utils.pr_agent.tools.pr_questions import PRQuestions
from utils.pr_agent.tools.pr_reviewer import PRReviewer
from utils.pr_agent.tools.pr_similar_issue import PRSimilarIssue
from utils.pr_agent.tools.pr_update_changelog import PRUpdateChangelog
# Maps a command name (without the leading "/") to the tool class that handles it.
# Several names are aliases for the same class.
command2class = {
    "auto_review": PRReviewer,
    "answer": PRReviewer,
    "review": PRReviewer,
    "review_pr": PRReviewer,
    "describe": PRDescription,
    "describe_pr": PRDescription,
    "improve": PRCodeSuggestions,
    "improve_code": PRCodeSuggestions,
    "ask": PRQuestions,
    "ask_question": PRQuestions,
    "ask_line": PR_LineQuestions,
    "update_changelog": PRUpdateChangelog,
    "config": PRConfig,
    "settings": PRConfig,
    "help": PRHelpMessage,
    "similar_issue": PRSimilarIssue,
    "add_docs": PRAddDocs,
    "generate_labels": PRGenerateLabels,
}
# All recognized command names.
commands = list(command2class.keys())
class PRAgent:
    """Dispatch a textual or pre-split command to the matching PR tool class."""

    def __init__(self, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
        self.ai_handler = ai_handler  # will be initialized in run_action

    async def handle_request(self, pr_url, request, notify=None) -> bool:
        """Parse ``request`` and run the corresponding tool on ``pr_url``.

        Args:
            pr_url: URL of the pull/merge request to operate on.
            request: either a command string (e.g. ``"/review --flag"``),
                split with shell-like rules, or an iterable whose first item
                is the command and whose remaining items are its arguments.
            notify: optional callable invoked before the tool runs (not
                called for ``auto_review``).

        Returns:
            True when a tool was run; False when the request is empty, an
            argument is forbidden, or the command is unknown.

        Raises:
            ValueError: propagated from ``shlex`` when a string request has
                unbalanced quotes.
        """
        # First, apply repo specific settings if exists
        apply_repo_settings(pr_url)

        # Then, apply user specific settings if exists
        if isinstance(request, str):
            request = request.replace("'", "\\'")
            lexer = shlex.shlex(request, posix=True)
            lexer.whitespace_split = True
            tokens = list(lexer)
        else:
            tokens = list(request)

        # An empty request has no command to unpack; report it like any other
        # invalid request instead of raising on the unpacking below.
        if not tokens:
            get_logger().error("Empty request: no command given")
            return False
        action, *args = tokens

        # validate args
        is_valid, arg = CliArgs.validate_user_args(args)
        if not is_valid:
            get_logger().error(
                f"CLI argument for param '{arg}' is forbidden. Use instead a configuration file."
            )
            return False

        # Update settings from args
        args = update_settings_from_args(args)

        action = action.lstrip("/").lower()
        if action not in command2class:
            get_logger().error(f"Unknown command: {action}")
            return False

        with get_logger().contextualize(command=action, pr_url=pr_url):
            get_logger().info("PR-Agent request handler started", analytics=True)
            if action == "answer":
                if notify:
                    notify()
                await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run()
            elif action == "auto_review":
                await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run()
            else:
                # action is known to be in command2class (checked above).
                if notify:
                    notify()
                await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run()
            return True

View File

@ -0,0 +1,103 @@
# Token limit per model name. Per the inline notes, the effective limit may be
# lowered further by config.max_model_tokens.
MAX_TOKENS = {
    'text-embedding-ada-002': 8000,
    'gpt-3.5-turbo': 16000,
    'gpt-3.5-turbo-0125': 16000,
    'gpt-3.5-turbo-0613': 4000,
    'gpt-3.5-turbo-1106': 16000,
    'gpt-3.5-turbo-16k': 16000,
    'gpt-3.5-turbo-16k-0613': 16000,
    'gpt-4': 8000,
    'gpt-4-0613': 8000,
    'gpt-4-32k': 32000,
    'gpt-4-1106-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4-0125-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4o': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4o-2024-05-13': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4-turbo-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4-turbo-2024-04-09': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4-turbo': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4o-mini': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4o-mini-2024-07-18': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4o-2024-08-06': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4o-2024-11-20': 128000,  # 128K, but may be limited by config.max_model_tokens
    'o1-mini': 128000,  # 128K, but may be limited by config.max_model_tokens
    'o1-mini-2024-09-12': 128000,  # 128K, but may be limited by config.max_model_tokens
    'o1-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
    'o1-preview-2024-09-12': 128000,  # 128K, but may be limited by config.max_model_tokens
    'o1-2024-12-17': 204800,  # 200K, but may be limited by config.max_model_tokens
    'o1': 204800,  # 200K, but may be limited by config.max_model_tokens
    'o3-mini': 204800,  # 200K, but may be limited by config.max_model_tokens
    'o3-mini-2025-01-31': 204800,  # 200K, but may be limited by config.max_model_tokens
    'claude-instant-1': 100000,
    'claude-2': 100000,
    'command-nightly': 4096,
    'deepseek/deepseek-chat': 128000,  # 128K, but may be limited by config.max_model_tokens
    'deepseek/deepseek-reasoner': 64000,  # 64K, but may be limited by config.max_model_tokens
    'replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1': 4096,
    'meta-llama/Llama-2-7b-chat-hf': 4096,
    'vertex_ai/codechat-bison': 6144,
    'vertex_ai/codechat-bison-32k': 32000,
    'vertex_ai/claude-3-haiku@20240307': 100000,
    'vertex_ai/claude-3-5-haiku@20241022': 100000,
    'vertex_ai/claude-3-sonnet@20240229': 100000,
    'vertex_ai/claude-3-opus@20240229': 100000,
    'vertex_ai/claude-3-5-sonnet@20240620': 100000,
    'vertex_ai/claude-3-5-sonnet-v2@20241022': 100000,
    'vertex_ai/gemini-1.5-pro': 1048576,
    'vertex_ai/gemini-1.5-flash': 1048576,
    'vertex_ai/gemini-2.0-flash': 1048576,
    'vertex_ai/gemma2': 8200,
    'gemini/gemini-1.5-pro': 1048576,
    'gemini/gemini-1.5-flash': 1048576,
    'gemini/gemini-2.0-flash': 1048576,
    'codechat-bison': 6144,
    'codechat-bison-32k': 32000,
    'anthropic.claude-instant-v1': 100000,
    'anthropic.claude-v1': 100000,
    'anthropic.claude-v2': 100000,
    'anthropic/claude-3-opus-20240229': 100000,
    'anthropic/claude-3-5-sonnet-20240620': 100000,
    'anthropic/claude-3-5-sonnet-20241022': 100000,
    'anthropic/claude-3-5-haiku-20241022': 100000,
    'bedrock/anthropic.claude-instant-v1': 100000,
    'bedrock/anthropic.claude-v2': 100000,
    'bedrock/anthropic.claude-v2:1': 100000,
    'bedrock/anthropic.claude-3-sonnet-20240229-v1:0': 100000,
    'bedrock/anthropic.claude-3-haiku-20240307-v1:0': 100000,
    'bedrock/anthropic.claude-3-5-haiku-20241022-v1:0': 100000,
    'bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0': 100000,
    'bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0': 100000,
    "bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0": 100000,
    'claude-3-5-sonnet': 100000,
    'groq/llama3-8b-8192': 8192,
    'groq/llama3-70b-8192': 8192,
    'groq/llama-3.1-8b-instant': 8192,
    'groq/llama-3.3-70b-versatile': 128000,
    'groq/mixtral-8x7b-32768': 32768,
    'groq/gemma2-9b-it': 8192,
    'ollama/llama3': 4096,
    'watsonx/meta-llama/llama-3-8b-instruct': 4096,
    "watsonx/meta-llama/llama-3-70b-instruct": 4096,
    "watsonx/meta-llama/llama-3-405b-instruct": 16384,
    "watsonx/ibm/granite-13b-chat-v2": 8191,
    "watsonx/ibm/granite-34b-code-instruct": 8191,
    "watsonx/mistralai/mistral-large": 32768,
}
# Models listed as "user message only"; presumably the system prompt has to be
# folded into the user message for these — TODO confirm against the handler.
USER_MESSAGE_ONLY_MODELS = [
    "deepseek/deepseek-reasoner",
    "o1-mini",
    "o1-mini-2024-09-12",
    "o1-preview"
]
# Models listed as not supporting the temperature argument.
NO_SUPPORT_TEMPERATURE_MODELS = [
    "deepseek/deepseek-reasoner",
    "o1-mini",
    "o1-mini-2024-09-12",
    "o1",
    "o1-2024-12-17",
    "o3-mini",
    "o3-mini-2025-01-31",
    "o1-preview"
]

View File

@ -0,0 +1,28 @@
from abc import ABC, abstractmethod
class BaseAiHandler(ABC):
    """
    This class defines the interface for an AI handler to be used by the PR Agents.

    Subclasses must implement ``__init__``, the ``deployment_id`` property and
    the ``chat_completion`` coroutine.
    """

    @abstractmethod
    def __init__(self):
        pass

    @property
    @abstractmethod
    def deployment_id(self):
        # Identifier of the deployment the handler talks to.
        pass

    @abstractmethod
    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
        """
        This method should be implemented to return a chat completion from the AI model.
        Args:
            model (str): the name of the model to use for the chat completion
            system (str): the system message string to use for the chat completion
            user (str): the user message string to use for the chat completion
            temperature (float): the temperature to use for the chat completion
            img_path (str): optional; presumably the location of an image to
                send along with the prompt — TODO confirm against implementations
        """
        pass

View File

@ -0,0 +1,74 @@
try:
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import AzureChatOpenAI, ChatOpenAI
except: # we don't enforce langchain as a dependency, so if it's not installed, just move on
pass
from openai import APIError, RateLimitError, Timeout
from retry import retry
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.log import get_logger
OPENAI_RETRIES = 5
class LangChainOpenAIHandler(BaseAiHandler):
    """AI handler that sends chat completions through LangChain's OpenAI clients.

    Uses ``AzureChatOpenAI`` when the ``OPENAI.API_TYPE`` setting is "azure",
    otherwise ``ChatOpenAI`` (optionally with a custom API base).
    """

    def __init__(self):
        # Initialize OpenAIHandler specific attributes here
        super().__init__()
        self.azure = get_settings().get("OPENAI.API_TYPE", "").lower() == "azure"

        # Create a default unused chat object to trigger early validation
        self._create_chat(self.deployment_id)

    def chat(self, messages: list, model: str, temperature: float):
        """Build a chat client and invoke it synchronously with ``messages``."""
        chat = self._create_chat(self.deployment_id)
        return chat.invoke(input=messages, model=model, temperature=temperature)

    @property
    def deployment_id(self):
        """
        Returns the deployment ID for the OpenAI API.
        """
        return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

    # NOTE(review): `retry` wraps the coroutine *function*; exceptions raised
    # while the coroutine is awaited happen outside the decorator, so the
    # retries presumably never trigger — confirm and consider an async-aware
    # retry.
    @retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError),
           tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2,
                              img_path: str = None):
        """Return ``(content, finish_reason)`` for one system + user prompt.

        Args:
            model (str): model name passed to the chat client
            system (str): system message
            user (str): user message
            temperature (float): sampling temperature
            img_path (str): accepted for compatibility with the
                ``BaseAiHandler.chat_completion`` signature; not used here

        Raises:
            Exception: any error from the underlying client is logged and re-raised.
        """
        try:
            messages = [SystemMessage(content=system), HumanMessage(content=user)]

            # get a chat completion from the formatted messages
            resp = self.chat(messages, model=model, temperature=temperature)
            finish_reason = "completed"
            return resp.content, finish_reason
        except Exception as e:
            # The "{}" placeholder is required: without it the exception was
            # passed as a format argument and never appeared in the message.
            get_logger().error("Unknown error during OpenAI inference: {}", e)
            raise

    def _create_chat(self, deployment_id=None):
        """Create the LangChain chat client for the configured API type.

        Raises:
            ValueError: when a required OpenAI setting is missing.
        """
        try:
            if self.azure:
                # using a partial function so we can set the deployment_id later to support fallback_deployments
                # but still need to access the other settings now so we can raise a proper exception if they're missing
                return AzureChatOpenAI(
                    openai_api_key=get_settings().openai.key,
                    openai_api_version=get_settings().openai.api_version,
                    azure_deployment=deployment_id,
                    azure_endpoint=get_settings().openai.api_base,
                )
            else:
                # for llms that compatible with openai, should use custom api base
                openai_api_base = get_settings().get("OPENAI.API_BASE", None)
                if not openai_api_base:
                    return ChatOpenAI(openai_api_key=get_settings().openai.key)
                else:
                    return ChatOpenAI(openai_api_key=get_settings().openai.key, openai_api_base=openai_api_base)
        except AttributeError as e:
            # Default to None so an AttributeError without a usable `name`
            # is re-raised unchanged instead of failing inside this handler.
            if getattr(e, "name", None):
                raise ValueError(f"OpenAI {e.name} is required") from e
            else:
                raise

View File

@ -0,0 +1,277 @@
import os
import litellm
import openai
import requests
from litellm import acompletion
from tenacity import retry, retry_if_exception_type, stop_after_attempt
from utils.pr_agent.algo import NO_SUPPORT_TEMPERATURE_MODELS, USER_MESSAGE_ONLY_MODELS
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.utils import get_version
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.log import get_logger
OPENAI_RETRIES = 5
class LiteLLMAIHandler(BaseAiHandler):
"""
This class handles interactions with the OpenAI API for chat completions.
It initializes the API key and other settings from a configuration file,
and provides a method for performing chat completions using the OpenAI ChatCompletion API.
"""
def __init__(self):
    """
    Copy provider credentials and options from the application settings
    onto the ``litellm`` / ``openai`` modules and the process environment.

    When no OpenAI key is configured and ``OPENAI_API_KEY`` is not in the
    environment, a placeholder key is set instead of raising.

    Raises:
        AssertionError: if AWS credentials are only partially configured, or
            if ``litellm.use_client`` is enabled without ``LITELLM_TOKEN``.
    """
    self.azure = False
    self.api_base = None
    self.repetition_penalty = None
    if get_settings().get("OPENAI.KEY", None):
        openai.api_key = get_settings().openai.key
        litellm.openai_key = get_settings().openai.key
    elif 'OPENAI_API_KEY' not in os.environ:
        # No key anywhere: install a placeholder rather than failing here.
        litellm.api_key = "dummy_key"
    if get_settings().get("aws.AWS_ACCESS_KEY_ID"):
        assert get_settings().aws.AWS_SECRET_ACCESS_KEY and get_settings().aws.AWS_REGION_NAME, "AWS credentials are incomplete"
        os.environ["AWS_ACCESS_KEY_ID"] = get_settings().aws.AWS_ACCESS_KEY_ID
        os.environ["AWS_SECRET_ACCESS_KEY"] = get_settings().aws.AWS_SECRET_ACCESS_KEY
        os.environ["AWS_REGION_NAME"] = get_settings().aws.AWS_REGION_NAME
    if get_settings().get("litellm.use_client"):
        litellm_token = get_settings().get("litellm.LITELLM_TOKEN")
        assert litellm_token, "LITELLM_TOKEN is required"
        os.environ["LITELLM_TOKEN"] = litellm_token
        litellm.use_client = True
    if get_settings().get("LITELLM.DROP_PARAMS", None):
        litellm.drop_params = get_settings().litellm.drop_params
    if get_settings().get("LITELLM.SUCCESS_CALLBACK", None):
        litellm.success_callback = get_settings().litellm.success_callback
    if get_settings().get("LITELLM.FAILURE_CALLBACK", None):
        litellm.failure_callback = get_settings().litellm.failure_callback
    if get_settings().get("LITELLM.SERVICE_CALLBACK", None):
        litellm.service_callback = get_settings().litellm.service_callback
    if get_settings().get("OPENAI.ORG", None):
        litellm.organization = get_settings().openai.org
    if get_settings().get("OPENAI.API_TYPE", None):
        if get_settings().openai.api_type == "azure":
            self.azure = True
            litellm.azure_key = get_settings().openai.key
    if get_settings().get("OPENAI.API_VERSION", None):
        litellm.api_version = get_settings().openai.api_version
    # litellm.api_base is assigned in several places below; the last
    # matching setting wins (OpenAI, then HuggingFace, then Ollama).
    if get_settings().get("OPENAI.API_BASE", None):
        litellm.api_base = get_settings().openai.api_base
    if get_settings().get("ANTHROPIC.KEY", None):
        litellm.anthropic_key = get_settings().anthropic.key
    if get_settings().get("COHERE.KEY", None):
        litellm.cohere_key = get_settings().cohere.key
    if get_settings().get("GROQ.KEY", None):
        # Overwrites litellm.api_key, including the placeholder set above.
        litellm.api_key = get_settings().groq.key
    if get_settings().get("REPLICATE.KEY", None):
        litellm.replicate_key = get_settings().replicate.key
    if get_settings().get("HUGGINGFACE.KEY", None):
        litellm.huggingface_key = get_settings().huggingface.key
    if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
        litellm.api_base = get_settings().huggingface.api_base
        self.api_base = get_settings().huggingface.api_base
    if get_settings().get("OLLAMA.API_BASE", None):
        litellm.api_base = get_settings().ollama.api_base
        self.api_base = get_settings().ollama.api_base
    if get_settings().get("HUGGINGFACE.REPETITION_PENALTY", None):
        self.repetition_penalty = float(get_settings().huggingface.repetition_penalty)
    if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
        litellm.vertex_project = get_settings().vertexai.vertex_project
        litellm.vertex_location = get_settings().get(
            "VERTEXAI.VERTEX_LOCATION", None
        )
    # Google AI Studio
    # SEE https://docs.litellm.ai/docs/providers/gemini
    if get_settings().get("GOOGLE_AI_STUDIO.GEMINI_API_KEY", None):
        os.environ["GEMINI_API_KEY"] = get_settings().google_ai_studio.gemini_api_key

    # Support deepseek models
    if get_settings().get("DEEPSEEK.KEY", None):
        os.environ['DEEPSEEK_API_KEY'] = get_settings().get("DEEPSEEK.KEY")

    # Models that only use user message
    self.user_message_only_models = USER_MESSAGE_ONLY_MODELS

    # Model that doesn't support temperature argument
    self.no_support_temperature_models = NO_SUPPORT_TEMPERATURE_MODELS
def prepare_logs(self, response, system, user, resp, finish_reason):
    """Build a loggable dict describing one LLM call.

    Starts from a copy of ``response.dict()`` and adds the prompts, the
    extracted answer text, the finish reason and the PR's main language.

    Args:
        response: completion response object exposing ``.dict()``.
        system: system prompt that was sent.
        user: user prompt that was sent.
        resp: answer text extracted from the response.
        finish_reason: finish reason extracted from the response.

    Returns:
        dict: the response fields plus ``system``, ``user``, ``output``,
        ``finish_reason`` and ``main_pr_language``.
    """
    response_log = response.dict().copy()
    response_log.update({
        'system': system,
        'user': user,
        'output': resp,
        'finish_reason': finish_reason,
        # 'unknown' when the handler never had main_pr_language assigned
        'main_pr_language': getattr(self, 'main_pr_language', 'unknown'),
    })
    return response_log
def add_litellm_callbacks(self, kwargs) -> dict:
    """Attach tracing metadata for the enabled litellm callbacks to ``kwargs``.

    The current logging context (``command`` and ``pr_url`` stored in the log
    record's ``extra``) is read by registering a temporary sink, emitting one
    debug message, and removing the sink again.

    Args:
        kwargs: completion keyword arguments; updated in place.

    Returns:
        dict: the same ``kwargs`` object with ``kwargs["metadata"]`` set.
        The metadata is empty when neither "langfuse" nor "langsmith" is
        among the configured litellm callbacks.
    """
    captured_extra = []

    def capture_logs(message):
        # Parsing the log message and context
        extra = message.record.get('extra') or {}
        log_entry = {}
        if extra.get('command') is not None:
            log_entry["command"] = extra["command"]
        if extra.get('pr_url') is not None:
            log_entry["pr_url"] = extra["pr_url"]
        # Append the log entry to the captured_logs list
        captured_extra.append(log_entry)

    # Adding the custom sink to Loguru; always remove it again, even if the
    # probe message fails, so the sink cannot leak into later logging.
    handler_id = get_logger().add(capture_logs)
    try:
        get_logger().debug("Capturing logs for litellm callbacks")
    finally:
        get_logger().remove(handler_id)

    # Nothing captured (e.g. the debug message was filtered out): fall back to
    # an empty context instead of dereferencing None.
    context = captured_extra[0] if captured_extra else {}

    command = context.get("command", "unknown")
    pr_url = context.get("pr_url", "unknown")
    git_provider = get_settings().config.git_provider

    metadata = dict()
    callbacks = litellm.success_callback + litellm.failure_callback + litellm.service_callback
    if "langfuse" in callbacks:
        metadata.update({
            "trace_name": command,
            "tags": [git_provider, command, f'version:{get_version()}'],
            "trace_metadata": {
                "command": command,
                "pr_url": pr_url,
            },
        })
    if "langsmith" in callbacks:
        metadata.update({
            "run_name": command,
            "tags": [git_provider, command, f'version:{get_version()}'],
            "extra": {
                "metadata": {
                    "command": command,
                    "pr_url": pr_url,
                }
            },
        })

    # Adding the captured logs to the kwargs
    kwargs["metadata"] = metadata
    return kwargs
@property
def deployment_id(self):
    """
    Deployment ID for the OpenAI API, read from the ``OPENAI.DEPLOYMENT_ID``
    setting; ``None`` when the setting is absent.
    """
    settings = get_settings()
    return settings.get("OPENAI.DEPLOYMENT_ID", None)
# NOTE(review): the original comment here said "No retry on RateLimitError", but
# openai.RateLimitError appears to subclass openai.APIError in openai>=1.0, so this
# predicate likely retries rate-limit errors as well — confirm and narrow if unintended.
@retry(
    retry=retry_if_exception_type((openai.APIError, openai.APIConnectionError, openai.APITimeoutError)),
    stop=stop_after_attempt(OPENAI_RETRIES)
)
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
    """Run one chat completion through litellm.

    Args:
        model: model name; prefixed with ``azure/`` when ``self.azure`` is set.
        system: system prompt.
        user: user prompt.
        temperature: sampling temperature; only sent when the model is not in
            ``self.no_support_temperature_models`` and no custom reasoning
            model is configured.
        img_path: optional image URL attached to the user message.

    Returns:
        tuple: ``(response_text, finish_reason)``. When the image link cannot
        be verified, ``(error_message, "error")`` is returned instead of
        calling the model.

    Raises:
        openai.APIError (and subclasses): on inference failure or an empty
        response; other exceptions are re-raised as ``openai.APIError``.
    """
    try:
        resp, finish_reason = None, None
        deployment_id = self.deployment_id
        if self.azure:
            model = 'azure/' + model
        if 'claude' in model and not system:
            system = "No system prompt provided"
            get_logger().warning(
                "Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error.")
        messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]

        if img_path:
            try:
                # check if the image link is alive
                r = requests.head(img_path, allow_redirects=True)
                if r.status_code == 404:
                    # Interpolate the actual URL; the message previously linked to the literal text "img_path".
                    error_msg = f"The image link is not [alive]({img_path}).\nPlease repost the original image as a comment, and send the question again with 'quote reply' (see [instructions](https://pr-agent-docs.codium.ai/tools/ask/#ask-on-images-using-the-pr-code-as-context))."
                    get_logger().error(error_msg)
                    return f"{error_msg}", "error"
            except Exception as e:
                # Put the exception into the message itself; as a bare positional
                # argument it had no placeholder and was never rendered.
                get_logger().error(f"Error fetching image: {img_path}: {e}")
                return f"Error fetching image: {img_path}", "error"
            messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]},
                                      {"type": "image_url", "image_url": {"url": img_path}}]

        # Currently, some models do not support a separate system and user prompts
        if model in self.user_message_only_models or get_settings().config.custom_reasoning_model:
            user = f"{system}\n\n\n{user}"
            system = ""
            get_logger().info(f"Using model {model}, combining system and user prompts")
            messages = [{"role": "user", "content": user}]

        # Same request arguments for both message layouts.
        kwargs = {
            "model": model,
            "deployment_id": deployment_id,
            "messages": messages,
            "timeout": get_settings().config.ai_timeout,
            "api_base": self.api_base,
        }

        # Add temperature only if model supports it
        if model not in self.no_support_temperature_models and not get_settings().config.custom_reasoning_model:
            kwargs["temperature"] = temperature

        if get_settings().litellm.get("enable_callbacks", False):
            kwargs = self.add_litellm_callbacks(kwargs)

        seed = get_settings().config.get("seed", -1)
        if temperature > 0 and seed >= 0:
            raise ValueError(f"Seed ({seed}) is not supported with temperature ({temperature}) > 0")
        elif seed >= 0:
            get_logger().info(f"Using fixed seed of {seed}")
            kwargs["seed"] = seed

        if self.repetition_penalty:
            kwargs["repetition_penalty"] = self.repetition_penalty

        get_logger().debug("Prompts", artifact={"system": system, "user": user})

        if get_settings().config.verbosity_level >= 2:
            get_logger().info(f"\nSystem prompt:\n{system}")
            get_logger().info(f"\nUser prompt:\n{user}")

        response = await acompletion(**kwargs)
    except openai.RateLimitError as e:
        # Must come before the APIError handler: placed after it, this branch
        # was never reached for rate-limit errors.
        get_logger().error(f"Rate limit error during LLM inference: {e}")
        raise
    except (openai.APIError, openai.APITimeoutError) as e:
        get_logger().warning(f"Error during LLM inference: {e}")
        raise
    except Exception as e:
        get_logger().warning(f"Unknown error during LLM inference: {e}")
        # NOTE(review): openai>=1.0 APIError seems to require constructor
        # arguments (message, request); raising the bare class may fail — confirm.
        raise openai.APIError from e
    if response is None or len(response["choices"]) == 0:
        raise openai.APIError
    else:
        resp = response["choices"][0]['message']['content']
        finish_reason = response["choices"][0]["finish_reason"]
        get_logger().debug(f"\nAI response:\n{resp}")

        # log the full response for debugging
        response_log = self.prepare_logs(response, system, user, resp, finish_reason)
        get_logger().debug("Full_response", artifact=response_log)

        # for CLI debugging
        if get_settings().config.verbosity_level >= 2:
            get_logger().info(f"\nAI response:\n{resp}")

    return resp, finish_reason

View File

@ -0,0 +1,67 @@
from os import environ
import openai
from openai import APIError, AsyncOpenAI, RateLimitError, Timeout
from retry import retry
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.log import get_logger
OPENAI_RETRIES = 5
class OpenAIHandler(BaseAiHandler):
    """AI handler that talks to the OpenAI API through ``AsyncOpenAI``.

    Configuration is read from the ``OPENAI.*`` settings and exported to the
    ``openai`` module / process environment in ``__init__``.
    """

    def __init__(self):
        """Configure the OpenAI client from settings.

        Raises:
            ValueError: when a required OpenAI setting cannot be read
                (surfaced as ``AttributeError`` by the settings object).
        """
        # Initialize OpenAIHandler specific attributes here
        try:
            super().__init__()
            environ["OPENAI_API_KEY"] = get_settings().openai.key
            if get_settings().get("OPENAI.ORG", None):
                openai.organization = get_settings().openai.org
            if get_settings().get("OPENAI.API_TYPE", None):
                if get_settings().openai.api_type == "azure":
                    self.azure = True
                    openai.azure_key = get_settings().openai.key
            if get_settings().get("OPENAI.API_VERSION", None):
                openai.api_version = get_settings().openai.api_version
            if get_settings().get("OPENAI.API_BASE", None):
                environ["OPENAI_BASE_URL"] = get_settings().openai.api_base
        except AttributeError as e:
            raise ValueError("OpenAI key is required") from e

    @property
    def deployment_id(self):
        """
        Returns the deployment ID for the OpenAI API.
        """
        return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

    # openai.APITimeoutError replaces the imported `Timeout` name here and below:
    # with the openai>=1.0 client this file uses (AsyncOpenAI), `openai.Timeout`
    # is a timeout-configuration type rather than an exception class, and listing
    # a non-exception in an `except` tuple raises TypeError once an error occurs.
    # NOTE(review): `retry.retry` wraps the call that creates the coroutine, not
    # the awaited work, so these retries may never fire for an async function — confirm.
    @retry(exceptions=(APIError, openai.APITimeoutError, AttributeError, RateLimitError),
           tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
        """Run one chat completion.

        Args:
            model: model name passed to the API.
            system: system prompt.
            user: user prompt.
            temperature: sampling temperature.

        Returns:
            tuple: ``(response_text, finish_reason)`` of the first choice.

        Raises:
            Whatever the client raises; every error is logged and re-raised.
        """
        try:
            # Format the payload into the message; passed as a bare positional
            # argument it had no placeholder and was never logged.
            get_logger().info(f"System: {system}")
            get_logger().info(f"User: {user}")
            messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
            client = AsyncOpenAI()
            chat_completion = await client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
            )
            resp = chat_completion.choices[0].message.content
            finish_reason = chat_completion.choices[0].finish_reason
            usage = chat_completion.usage
            get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
                              model=model, usage=usage)
            return resp, finish_reason
        except RateLimitError as e:
            # Checked first so the rate-limit message is actually reachable.
            get_logger().error(f"Rate limit error during OpenAI inference: {e}")
            raise
        except (APIError, openai.APITimeoutError) as e:
            get_logger().error(f"Error during OpenAI inference: {e}")
            raise
        except Exception as e:
            get_logger().error(f"Unknown error during OpenAI inference: {e}")
            raise

View File

@ -0,0 +1,34 @@
from base64 import b64decode
import hashlib
class CliArgs:
    """Validation helpers for user-supplied command-line style arguments."""

    @staticmethod
    def validate_user_args(args: list) -> (bool, str):
        """Check that no ``--`` argument touches a forbidden setting.

        Args:
            args: argument strings, e.g. ``["--pr_reviewer.foo=1"]``.

        Returns:
            ``(True, "")`` when all arguments are allowed (or ``args`` is
            empty), ``(False, <forbidden word>)`` for the first forbidden
            match, or ``(False, <error text>)`` if validation itself fails.
        """
        try:
            if not args:
                return True, ""

            # decode forbidden args
            _encoded_args = 'ZW5hYmxlX2F1dG9fYXBwcm92YWw=:YXBwcm92ZV9wcl9vbl9zZWxmX3Jldmlldw==:YmFzZV91cmw=:dXJs:YXBwX25hbWU=:c2VjcmV0X3Byb3ZpZGVy:Z2l0X3Byb3ZpZGVy:c2tpcF9rZXlz:b3BlbmFpLmtleQ==:QU5BTFlUSUNTX0ZPTERFUg==:dXJp:YXBwX2lk:d2ViaG9va19zZWNyZXQ=:YmVhcmVyX3Rva2Vu:UEVSU09OQUxfQUNDRVNTX1RPS0VO:b3ZlcnJpZGVfZGVwbG95bWVudF90eXBl:cHJpdmF0ZV9rZXk=:bG9jYWxfY2FjaGVfcGF0aA==:ZW5hYmxlX2xvY2FsX2NhY2hl:amlyYV9iYXNlX3VybA==:YXBpX2Jhc2U=:YXBpX3R5cGU=:YXBpX3ZlcnNpb24=:c2tpcF9rZXlz'
            decoded_words = [b64decode(token).decode().lower() for token in _encoded_args.split(':')]
            # bare key names are matched as ".key" so they only hit setting attributes
            forbidden_cli_args = [word if '.' in word else '.' + word for word in decoded_words]

            for arg in args:
                if not arg.startswith('--'):
                    continue
                # replace double underscore with dot, e.g. --openai__key -> --openai.key
                arg_word = arg.lower().replace('__', '.')
                hit = next((word for word in forbidden_cli_args if word in arg_word), None)
                if hit is not None:
                    return False, hit
            return True, ""
        except Exception as e:
            return False, str(e)

View File

@ -0,0 +1,65 @@
import fnmatch
import re
from utils.pr_agent.config_loader import get_settings
def filter_ignored(files, platform = 'github'):
    """
    Filter out files that match the ignore patterns.

    Patterns come from the ``ignore.regex`` setting plus the ``ignore.glob``
    setting (globs are translated to regexes). Patterns that fail to compile
    are skipped.

    Args:
        files: list of file entries. Their shape depends on ``platform``:
            'github' entries expose ``.filename``; 'bitbucket' entries expose
            ``.new.path`` / ``.old.path``; 'gitlab' entries are mappings with
            'new_path' / 'old_path'; 'azure' entries are plain path strings.
        platform: one of 'github', 'bitbucket', 'gitlab', 'azure'.

    Returns:
        The filtered list. If anything goes wrong the error is printed and the
        list is returned as filtered so far (best effort).
    """
    try:
        # load regex patterns, and translate glob patterns to regex
        patterns = get_settings().ignore.regex
        if isinstance(patterns, str):
            patterns = [patterns]
        else:
            # Work on a copy: `+=` below would otherwise extend the settings'
            # own list in place, piling up translated globs on every call.
            patterns = list(patterns)
        glob_setting = get_settings().ignore.glob
        if isinstance(glob_setting, str):  # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
            glob_setting = glob_setting.strip('[]').split(",")
        patterns += [fnmatch.translate(glob) for glob in glob_setting]

        # compile all valid patterns
        compiled_patterns = []
        for r in patterns:
            try:
                compiled_patterns.append(re.compile(r))
            except re.error:
                pass

        # keep filenames that _don't_ match the ignore regex
        if files and isinstance(files, list):
            for r in compiled_patterns:
                if platform == 'github':
                    files = [f for f in files if (f.filename and not r.match(f.filename))]
                elif platform == 'bitbucket':
                    # keep a file when its new path (or, failing that, its old path) is not ignored
                    files_o = []
                    for f in files:
                        if hasattr(f, 'new'):
                            if f.new and f.new.path and not r.match(f.new.path):
                                files_o.append(f)
                                continue
                        if hasattr(f, 'old'):
                            if f.old and f.old.path and not r.match(f.old.path):
                                files_o.append(f)
                                continue
                    files = files_o
                elif platform == 'gitlab':
                    # same rule as bitbucket, on mapping-style entries
                    files_o = []
                    for f in files:
                        if 'new_path' in f and f['new_path'] and not r.match(f['new_path']):
                            files_o.append(f)
                            continue
                        if 'old_path' in f and f['old_path'] and not r.match(f['old_path']):
                            files_o.append(f)
                            continue
                    files = files_o
                elif platform == 'azure':
                    files = [f for f in files if not r.match(f)]

    except Exception as e:
        print(f"Could not filter file list: {e}")

    return files

View File

@ -0,0 +1,414 @@
from __future__ import annotations
import re
import traceback
from utils.pr_agent.algo.types import EDIT_TYPE
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.log import get_logger
def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
                 patch_extra_lines_after=0, filename: str = "") -> str:
    """Add surrounding context lines from the original file to each hunk of a patch.

    Args:
        original_file_str: original file content (str or bytes).
        patch_str: unified-diff patch text.
        patch_extra_lines_before: context lines to add before each hunk.
        patch_extra_lines_after: context lines to add after each hunk.
        filename: file name, used to skip configured extensions.

    Returns:
        The extended patch, or ``patch_str`` unchanged when there is nothing
        to extend, the file cannot be decoded, the file type is skipped, or
        extension fails.
    """
    # Guard clauses: nothing to extend, or nothing to extend it with.
    if not patch_str:
        return patch_str
    if patch_extra_lines_before == 0 and patch_extra_lines_after == 0:
        return patch_str
    if not original_file_str:
        return patch_str

    original_file_str = decode_if_bytes(original_file_str)
    if not original_file_str:
        return patch_str

    if should_skip_patch(filename):
        return patch_str

    try:
        return process_patch_lines(patch_str, original_file_str,
                                   patch_extra_lines_before, patch_extra_lines_after)
    except Exception as e:
        get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
        return patch_str
def decode_if_bytes(original_file_str):
    """Return text for ``bytes``/``bytearray`` input; pass anything else through.

    UTF-8 is tried first, then a list of fallback encodings in order. An empty
    string is returned if every attempt raises ``UnicodeDecodeError``.
    """
    if not isinstance(original_file_str, (bytes, bytearray)):
        return original_file_str

    for encoding in ('utf-8', 'iso-8859-1', 'latin-1', 'ascii', 'utf-16'):
        try:
            return original_file_str.decode(encoding)
        except UnicodeDecodeError:
            continue
    return ""
def should_skip_patch(filename):
    """Tell whether ``filename`` ends with one of the configured
    ``config.patch_extension_skip_types`` suffixes.

    Returns ``False`` when the setting or the filename is empty.
    """
    skip_types = get_settings().config.patch_extension_skip_types
    if not (skip_types and filename):
        return False
    # str.endswith accepts a tuple of candidate suffixes
    return filename.endswith(tuple(skip_types))
def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after):
    """Rewrite a unified-diff patch so each hunk carries extra context lines.

    For every hunk header that parses, lines taken from the original file are
    inserted before the hunk body and appended after it, and the header's
    start/size numbers are adjusted to match. Hunks whose first context line
    does not match the original file are left unextended.

    Args:
        patch_str: unified-diff patch text.
        original_file_str: original (pre-change) file content as text.
        patch_extra_lines_before: number of context lines to add before a hunk.
        patch_extra_lines_after: number of context lines to add after a hunk.

    Returns:
        str: the extended patch, or ``patch_str`` unchanged if processing raises.
    """
    allow_dynamic_context = get_settings().config.allow_dynamic_context
    patch_extra_lines_before_dynamic = get_settings().config.max_extra_lines_before_dynamic_context

    original_lines = original_file_str.splitlines()
    len_original_lines = len(original_lines)
    patch_lines = patch_str.splitlines()
    extended_patch_lines = []

    is_valid_hunk = True
    # -1 marks "no hunk seen yet"
    start1, size1, start2, size2 = -1, -1, -1, -1
    RE_HUNK_HEADER = re.compile(
        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
    try:
        for i,line in enumerate(patch_lines):
            if line.startswith('@@'):
                match = RE_HUNK_HEADER.match(line)
                # identify hunk header
                if match:
                    # finish processing previous hunk
                    # (append trailing context using the previous hunk's start1/size1)
                    if is_valid_hunk and (start1 != -1 and patch_extra_lines_after > 0):
                        delta_lines = [f' {line}' for line in original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]]
                        extended_patch_lines.extend(delta_lines)

                    section_header, size1, size2, start1, start2 = extract_hunk_headers(match)

                    is_valid_hunk = check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1)

                    if is_valid_hunk and (patch_extra_lines_before > 0 or patch_extra_lines_after > 0):
                        def _calc_context_limits(patch_lines_before):
                            # New start/size for both sides of the hunk when
                            # `patch_lines_before` lines are prepended and
                            # `patch_extra_lines_after` lines are appended.
                            extended_start1 = max(1, start1 - patch_lines_before)
                            extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after
                            extended_start2 = max(1, start2 - patch_lines_before)
                            extended_size2 = size2 + (start2 - extended_start2) + patch_extra_lines_after
                            if extended_start1 - 1 + extended_size1 > len_original_lines:
                                # we cannot extend beyond the original file
                                delta_cap = extended_start1 - 1 + extended_size1 - len_original_lines
                                extended_size1 = max(extended_size1 - delta_cap, size1)
                                extended_size2 = max(extended_size2 - delta_cap, size2)
                            return extended_start1, extended_size1, extended_start2, extended_size2

                        if allow_dynamic_context:
                            # Look further back (dynamic limit) for the line holding the
                            # section header; if found, start the context at that line.
                            extended_start1, extended_size1, extended_start2, extended_size2 = \
                                _calc_context_limits(patch_extra_lines_before_dynamic)
                            lines_before = original_lines[extended_start1 - 1:start1 - 1]
                            found_header = False
                            # NOTE: this loop rebinds the outer loop's `i` and `line`; the
                            # outer values are not read again before the `continue` below.
                            for i, line, in enumerate(lines_before):
                                if section_header in line:
                                    found_header = True
                                    # Update start and size in one line each
                                    extended_start1, extended_start2 = extended_start1 + i, extended_start2 + i
                                    extended_size1, extended_size2 = extended_size1 - i, extended_size2 - i
                                    # get_logger().debug(f"Found section header in line {i} before the hunk")
                                    section_header = ''
                                    break
                            if not found_header:
                                # get_logger().debug(f"Section header not found in the extra lines before the hunk")
                                # fall back to the regular (non-dynamic) number of lines before
                                extended_start1, extended_size1, extended_start2, extended_size2 = \
                                    _calc_context_limits(patch_extra_lines_before)
                        else:
                            extended_start1, extended_size1, extended_start2, extended_size2 = \
                                _calc_context_limits(patch_extra_lines_before)

                        # context lines inserted before the hunk body, prefixed with a space
                        delta_lines = [f' {line}' for line in original_lines[extended_start1 - 1:start1 - 1]]

                        # logic to remove section header if its in the extra delta lines (in dynamic context, this is also done)
                        if section_header and not allow_dynamic_context:
                            for line in delta_lines:
                                if section_header in line:
                                    section_header = ''  # remove section header if it is in the extra delta lines
                                    break
                    else:
                        # invalid hunk or no extension requested: keep the header numbers as parsed
                        extended_start1 = start1
                        extended_size1 = size1
                        extended_start2 = start2
                        extended_size2 = size2
                        delta_lines = []
                    extended_patch_lines.append('')
                    extended_patch_lines.append(
                        f'@@ -{extended_start1},{extended_size1} '
                        f'+{extended_start2},{extended_size2} @@ {section_header}')
                    extended_patch_lines.extend(delta_lines)  # one to zero based
                    continue
            # body lines, and '@@' lines that did not parse, are copied through unchanged
            extended_patch_lines.append(line)
    except Exception as e:
        get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
        return patch_str

    # finish processing last hunk
    if start1 != -1 and patch_extra_lines_after > 0 and is_valid_hunk:
        delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
        # add space at the beginning of each extra line
        delta_lines = [f' {line}' for line in delta_lines]
        extended_patch_lines.extend(delta_lines)

    extended_patch_str = '\n'.join(extended_patch_lines)
    return extended_patch_str
def check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1):
    """
    Check if the hunk lines match the original file content. We saw cases where the hunk header line doesn't match the original file content, and then
    extending the hunk with extra lines before the hunk header can cause the hunk to be invalid.

    Args:
        i: index of the hunk header inside ``patch_lines``.
        original_lines: lines of the original file.
        patch_lines: lines of the patch.
        start1: 1-based start line of the hunk in the original file.

    Returns:
        bool: ``False`` only when the first line after the header is a context
        line (leading space) that differs, ignoring surrounding whitespace,
        from line ``start1`` of the original file. Every other case, including
        any error during the check, yields ``True`` (best effort).
    """
    is_valid_hunk = True
    try:
        if i + 1 < len(patch_lines) and patch_lines[i + 1].startswith(' '):  # an existing line in the file
            if patch_lines[i + 1].strip() != original_lines[start1 - 1].strip():
                is_valid_hunk = False
                get_logger().error(
                    f"Invalid hunk in PR, line {start1} in hunk header doesn't match the original file content")
    except Exception:
        # Best effort: keep the current verdict on lookup errors, but no longer
        # swallow KeyboardInterrupt/SystemExit the way a bare `except:` does.
        pass
    return is_valid_hunk
def extract_hunk_headers(match):
    """Extract the numbers and section text from a parsed hunk header.

    Args:
        match: regex match whose first five groups are
            ``(start1, size1, start2, size2, section_header)``.

    Returns:
        tuple: ``(section_header, size1, size2, start1, start2)`` with the
        four numbers as ints.
    """
    groups = match.groups()
    start1, size1, start2, size2 = groups[:4]
    section_header = groups[4]

    # In the unified format a hunk length is omitted when the hunk covers
    # exactly one line ('@@ -5 +6 @@'), so a missing size means 1, not 0.
    size1 = 1 if size1 is None else int(size1)
    size2 = 1 if size2 is None else int(size2)
    # A missing start falls back to 0, as before.
    start1 = 0 if start1 is None else int(start1)
    start2 = 0 if start2 is None else int(start2)
    if section_header is None:
        section_header = ''
    return section_header, size1, size2, start1, start2
def omit_deletion_hunks(patch_lines) -> str:
    """
    Omit deletion hunks from the patch and return the modified patch.

    A hunk is kept only if at least one of its lines starts with '+'.

    Args:
    - patch_lines: a list of strings representing the lines of the patch

    Returns:
    - A string representing the modified patch with deletion hunks omitted
    """

    temp_hunk = []
    added_patched = []
    add_hunk = False
    inside_hunk = False
    RE_HUNK_HEADER = re.compile(
        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))?\ @@[ ]?(.*)")

    for line in patch_lines:
        if line.startswith('@@'):
            match = RE_HUNK_HEADER.match(line)
            if match:
                # finish previous hunk
                if inside_hunk:
                    if add_hunk:
                        added_patched.extend(temp_hunk)
                    # Always start the next hunk from a clean buffer. Resetting only
                    # after a kept hunk let a deletion-only hunk ride along with the
                    # next hunk that contained additions.
                    temp_hunk = []
                    add_hunk = False
                temp_hunk.append(line)
                inside_hunk = True
        else:
            temp_hunk.append(line)
            if line:
                edit_type = line[0]
                if edit_type == '+':
                    add_hunk = True
    # flush the last hunk
    if inside_hunk and add_hunk:
        added_patched.extend(temp_hunk)

    return '\n'.join(added_patched)
def handle_patch_deletions(patch: str, original_file_content_str: str,
                           new_file_content_str: str, file_name: str, edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN) -> str:
    """
    Reduce a patch to what is worth showing when content was deleted.

    A file with no new content whose edit type is DELETED or UNKNOWN is
    treated as a deleted file and yields no patch at all. Otherwise the
    deletion-only hunks are dropped from the patch.

    Args:
        patch (str): The patch to be handled.
        original_file_content_str (str): The original content of the file.
        new_file_content_str (str): The new content of the file.
        file_name (str): The name of the file (used for logging).
        edit_type (EDIT_TYPE): The kind of edit made to the file.

    Returns:
        str: ``None`` for a deleted file, otherwise the patch with deletion
        hunks omitted.
    """
    is_deleted_file = (not new_file_content_str) and edit_type in (EDIT_TYPE.DELETED, EDIT_TYPE.UNKNOWN)
    if is_deleted_file:
        # logic for handling deleted files - don't show patch, just show that the file was deleted
        if get_settings().config.verbosity_level > 0:
            get_logger().info(f"Processing file: {file_name}, minimizing deletion file")
        return None  # file was deleted

    patch_new = omit_deletion_hunks(patch.splitlines())
    if patch == patch_new:
        return patch

    if get_settings().config.verbosity_level > 0:
        get_logger().info(f"Processing file: {file_name}, hunks were deleted")
    return patch_new
def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
    """
    Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of
    the file.

    Args:
        patch (str): The patch string to be converted.
        file: An object containing the filename of the file being patched.

    Returns:
        str: A string with line numbers for each hunk, indicating the new and old content of the file.

    example output:
    ## src/file.ts
    __new hunk__
    881        line1
    882        line2
    883        line3
    887 +      line4
    888 +      line5
    889        line6
    890        line7
    ...
    __old hunk__
            line1
            line2
    -       line3
    -       line4
            line5
            line6
            ...
    """
    # if the file was deleted, return a message indicating that the file was deleted
    if hasattr(file, 'edit_type') and file.edit_type == EDIT_TYPE.DELETED:
        return f"\n\n## file '{file.filename.strip()}' was deleted\n"

    patch_with_lines_str = f"\n\n## File: '{file.filename.strip()}'\n"
    patch_lines = patch.splitlines()
    RE_HUNK_HEADER = re.compile(
        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
    # Lines collected for the hunk currently being read: context lines go to
    # both lists, '+' lines only to new, '-' lines only to old.
    new_content_lines = []
    old_content_lines = []
    match = None
    start1, size1, start2, size2 = -1, -1, -1, -1
    prev_header_line = []
    header_line = []
    for line_i, line in enumerate(patch_lines):
        if 'no newline at end of file' in line.lower():
            continue

        if line.startswith('@@'):
            header_line = line
            match = RE_HUNK_HEADER.match(line)
            if match and (new_content_lines or old_content_lines):  # found a new hunk, split the previous lines
                if prev_header_line:
                    patch_with_lines_str += f'\n{prev_header_line}\n'
                is_plus_lines = is_minus_lines = False
                if new_content_lines:
                    is_plus_lines = any([line.startswith('+') for line in new_content_lines])
                if old_content_lines:
                    is_minus_lines = any([line.startswith('-') for line in old_content_lines])
                # the new-hunk section is written whenever the hunk has any '+' or '-' line
                if is_plus_lines or is_minus_lines: # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
                    patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__new hunk__\n'
                    # number new-side lines starting from the hunk's start2
                    for i, line_new in enumerate(new_content_lines):
                        patch_with_lines_str += f"{start2 + i} {line_new}\n"
                if is_minus_lines:
                    patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__old hunk__\n'
                    for line_old in old_content_lines:
                        patch_with_lines_str += f"{line_old}\n"
                new_content_lines = []
                old_content_lines = []
            if match:
                prev_header_line = header_line

            # NOTE(review): this appears to run even when the '@@' line did not
            # parse (match is None), in which case extract_hunk_headers would
            # raise on `match.groups()` — confirm against the original layout.
            section_header, size1, size2, start1, start2 = extract_hunk_headers(match)

        elif line.startswith('+'):
            new_content_lines.append(line)
        elif line.startswith('-'):
            old_content_lines.append(line)
        else:
            if not line and line_i: # if this line is empty and the next line is a hunk header, skip it
                if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith('@@'):
                    continue
                elif line_i + 1 == len(patch_lines):
                    # an empty last line of the patch is skipped as well
                    continue
            new_content_lines.append(line)
            old_content_lines.append(line)

    # finishing last hunk
    if match and new_content_lines:
        patch_with_lines_str += f'\n{header_line}\n'
        is_plus_lines = is_minus_lines = False
        if new_content_lines:
            is_plus_lines = any([line.startswith('+') for line in new_content_lines])
        if old_content_lines:
            is_minus_lines = any([line.startswith('-') for line in old_content_lines])
        if is_plus_lines or is_minus_lines:  # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
            patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__new hunk__\n'
            for i, line_new in enumerate(new_content_lines):
                patch_with_lines_str += f"{start2 + i} {line_new}\n"
        if is_minus_lines:
            patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__old hunk__\n'
            for line_old in old_content_lines:
                patch_with_lines_str += f"{line_old}\n"

    return patch_with_lines_str.rstrip()
def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, side) -> tuple[str, str]:
    """Extract the hunk(s) containing ``line_start`` and the selected lines within them.

    Args:
        patch: unified-diff patch text.
        file_name: file name used in the output header.
        line_start: first selected line number.
        line_end: last selected line number.
        side: 'left' to interpret the range against the old-file numbers of
            the hunk header, 'right' for the new-file numbers (case-insensitive).

    Returns:
        tuple[str, str]: ``(hunk text with a file header, selected lines)``,
        or ``("", "")`` if extraction fails.
    """
    try:
        patch_with_lines_str = f"\n\n## File: '{file_name.strip()}'\n\n"
        selected_lines = ""
        patch_lines = patch.splitlines()
        RE_HUNK_HEADER = re.compile(
            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
        match = None
        start1, size1, start2, size2 = -1, -1, -1, -1
        skip_hunk = False
        selected_lines_num = 0
        side_lower = side.lower()
        for line in patch_lines:
            if 'no newline at end of file' in line.lower():
                continue

            if line.startswith('@@'):
                skip_hunk = False
                selected_lines_num = 0
                header_line = line

                match = RE_HUNK_HEADER.match(line)

                section_header, size1, size2, start1, start2 = extract_hunk_headers(match)

                # check if line range is in this hunk
                if side_lower == 'left':
                    if not (start1 <= line_start <= start1 + size1):
                        skip_hunk = True
                        continue
                elif side_lower == 'right':
                    if not (start2 <= line_start <= start2 + size2):
                        skip_hunk = True
                        continue
                patch_with_lines_str += f'\n{header_line}\n'

            elif not skip_hunk:
                if side_lower == 'right' and line_start <= start2 + selected_lines_num <= line_end:
                    selected_lines += line + '\n'
                # The lower bound is line_start, mirroring the 'right' branch. It used
                # to be start1, which is always <= start1 + n, so every line from the
                # top of the hunk up to line_end was selected.
                if side_lower == 'left' and line_start <= selected_lines_num + start1 <= line_end:
                    selected_lines += line + '\n'
                patch_with_lines_str += line + '\n'
                if not line.startswith('-'):  # currently we don't support /ask line for deleted lines
                    selected_lines_num += 1
    except Exception as e:
        get_logger().error(f"Failed to extract hunk lines from patch: {e}", artifact={"traceback": traceback.format_exc()})
        return "", ""

    return patch_with_lines_str.rstrip(), selected_lines.rstrip()

View File

@ -0,0 +1,70 @@
# Language Selection, source: https://github.com/bigcode-project/bigcode-dataset/blob/main/language_selection/programming-languages-to-file-extensions.json # noqa E501
from typing import Dict
from utils.pr_agent.config_loader import get_settings
def filter_bad_extensions(files):
    """Return the files whose name is set and whose extension is not black-listed.

    The black-list is ``bad_extensions.default`` from settings, plus
    ``bad_extensions.extra`` when ``config.use_extra_bad_extensions`` is on.

    Args:
        files: iterable of objects exposing ``.filename``.

    Returns:
        list: the entries that pass ``is_valid_file``.
    """
    # Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py  # noqa: E501
    # Copy first: `+=` on the settings' own list would append the extra
    # extensions to the shared default list again on every call.
    bad_extensions = list(get_settings().bad_extensions.default)
    if get_settings().config.use_extra_bad_extensions:
        bad_extensions.extend(get_settings().bad_extensions.extra)
    return [f for f in files if f.filename is not None and is_valid_file(f.filename, bad_extensions)]
def is_valid_file(filename:str, bad_extensions=None) -> bool:
    """Tell whether ``filename`` has an extension that is not black-listed.

    Args:
        filename: file name or path; an empty value is never valid.
        bad_extensions: extensions (without the dot) to reject. When empty or
            omitted, the list is built from settings.

    Returns:
        bool: ``True`` if the text after the last '.' is not in the black-list.
    """
    if not filename:
        return False
    if not bad_extensions:
        # Copy first: `+=` on the settings' own list would append the extra
        # extensions to the shared default list again on every call.
        bad_extensions = list(get_settings().bad_extensions.default)
        if get_settings().config.use_extra_bad_extensions:
            bad_extensions.extend(get_settings().bad_extensions.extra)
    return filename.split('.')[-1] not in bad_extensions
def sort_files_by_main_languages(languages: Dict, files: list):
    """
    Group files by the repository's languages, largest language first.

    Args:
        languages: mapping of language name to its size in the repository.
        files: objects exposing ``.filename``.

    Returns:
        A list of ``{"language": ..., "files": [...]}`` dicts: one per language
        that has matching files (in descending size order), followed by an
        "Other" entry holding files whose extension belongs to none of the
        listed languages. With no languages, a single "Other" entry with all
        (filtered) files is returned.
    """
    # languages ordered by size, biggest first
    ranked_languages = [name for name, _ in sorted(languages.items(), key=lambda item: item[1], reverse=True)]

    # extensions of each ranked language (case-insensitive lookup by name)
    configured_map = get_settings().language_extension_map_org
    extension_map = {name.lower(): exts for name, exts in configured_map.items()}
    main_extensions = [extension_map.get(name.lower(), []) for name in ranked_languages]

    # filter out files bad extensions
    files_filtered = filter_bad_extensions(files)

    # if no languages detected, put all files in the "Other" category
    if not languages:
        return [{"language": "Other", "files": list(files_filtered)}]

    main_extensions_flat = [ext for group in main_extensions for ext in group]

    files_sorted = []
    rest_files = {}
    for extensions, lang in zip(main_extensions, ranked_languages):  # noqa: B905
        matched = []
        for file in files_filtered:
            extension_str = f".{file.filename.split('.')[-1]}"
            if extension_str in extensions:
                matched.append(file)
            elif (file.filename not in rest_files) and (extension_str not in main_extensions_flat):
                # belongs to no listed language at all -> "Other"
                rest_files[file.filename] = file
        if matched:
            files_sorted.append({"language": lang, "files": matched})

    files_sorted.append({"language": "Other", "files": list(rest_files.values())})
    return files_sorted

View File

@ -0,0 +1,550 @@
from __future__ import annotations
import traceback
from typing import Callable, List, Tuple
from github import RateLimitExceededException
from utils.pr_agent.algo.file_filter import filter_ignored
from utils.pr_agent.algo.git_patch_processing import (
convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions)
from utils.pr_agent.algo.language_handler import sort_files_by_main_languages
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.types import EDIT_TYPE
from utils.pr_agent.algo.utils import ModelType, clip_tokens, get_max_tokens, get_weak_model
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers.git_provider import GitProvider
from utils.pr_agent.log import get_logger
DELETED_FILES_ = "Deleted files:\n"
MORE_MODIFIED_FILES_ = "Additional modified files (insufficient token budget to process):\n"
ADDED_FILES_ = "Additional added files (insufficient token budget to process):\n"
OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1500
OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 1000
MAX_EXTRA_LINES = 10
def cap_and_log_extra_lines(value, direction) -> int:
    """Limit an extra-lines setting to ``MAX_EXTRA_LINES``.

    Args:
        value: requested number of extra patch lines.
        direction: "before" or "after"; only used in the warning text.

    Returns:
        ``value`` when it does not exceed the cap, otherwise the cap (a
        warning is logged in that case).
    """
    if not value > MAX_EXTRA_LINES:
        return value
    get_logger().warning(f"patch_extra_lines_{direction} was {value}, capping to {MAX_EXTRA_LINES}")
    return MAX_EXTRA_LINES
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
                model: str,
                add_line_numbers_to_hunks: bool = False,
                disable_extra_lines: bool = False,
                large_pr_handling=False,
                return_remaining_files=False):
    """
    Build the diff text of a PR, pruning it when it does not fit the model's token budget.

    The function first tries the "extended" diff (patches with extra context lines). If that fits under
    get_max_tokens(model) minus the soft output buffer, it is returned as is. Otherwise the compressed diff
    is generated, the first compressed patch group is used, and lists of the added / modified / deleted
    files that were left out are appended while token budget remains.

    Args:
        git_provider: Provider used to fetch the diff files and the repository languages.
        token_handler: Supplies prompt_tokens and count_tokens for the budget computations.
        model: Model name passed to get_max_tokens.
        add_line_numbers_to_hunks: Forwarded to the extended / compressed diff generators.
        disable_extra_lines: When True, no extra context lines are requested; otherwise the configured
            patch_extra_lines_before/after values are used, capped by cap_and_log_extra_lines.
        large_pr_handling: Forwarded to pr_generate_compressed_diff. When set and more than one
            compressed patch group is produced, an empty string is returned.
        return_remaining_files: When True and the pruning path is reached, a tuple
            (final_diff, remaining_files_list) is returned instead of only final_diff.

    Returns:
        A diff string, or a (diff string, remaining files) tuple as described above. Note that the two
        early returns (full diff fits / large PR mode) always return a plain string, even when
        return_remaining_files is True.

    Raises:
        RateLimitExceededException: Re-raised after logging when fetching the diff files hits the rate limit.
    """
    if disable_extra_lines:
        PATCH_EXTRA_LINES_BEFORE = 0
        PATCH_EXTRA_LINES_AFTER = 0
    else:
        PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
        PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
        PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(PATCH_EXTRA_LINES_BEFORE, "before")
        PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(PATCH_EXTRA_LINES_AFTER, "after")

    try:
        diff_files_original = git_provider.get_diff_files()
    except RateLimitExceededException as e:
        get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
        raise

    diff_files = filter_ignored(diff_files_original)
    if diff_files != diff_files_original:
        # Reporting which files were filtered is informational only; any failure here is swallowed.
        try:
            get_logger().info(f"Filtered out {len(diff_files_original) - len(diff_files)} files")
            new_names = set([a.filename for a in diff_files])
            orig_names = set([a.filename for a in diff_files_original])
            get_logger().info(f"Filtered out files: {orig_names - new_names}")
        except Exception as e:
            pass

    # get pr languages
    pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
    if pr_languages:
        # Informational logging only; a malformed first entry must not abort the diff generation.
        try:
            get_logger().info(f"PR main language: {pr_languages[0]['language']}")
        except Exception as e:
            pass

    # generate a standard diff string, with patch extension
    patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
        pr_languages, token_handler, add_line_numbers_to_hunks,
        patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE, patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)

    # if we are under the limit, return the full diff
    if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
        get_logger().info(f"Tokens: {total_tokens}, total tokens under limit: {get_max_tokens(model)}, "
                          f"returning full diff.")
        return "\n".join(patches_extended)

    # if we are over the limit, start pruning (If we got here, we will not extend the patches with extra lines)
    get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
                      f"pruning diff.")
    patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
        pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling)

    if large_pr_handling and len(patches_compressed_list) > 1:
        get_logger().info(f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff.")
        return "" # return empty string, as we want to generate multiple patches with a different prompt

    # return the first patch
    patches_compressed = patches_compressed_list[0]
    total_tokens_new = total_tokens_list[0]
    files_in_patch = files_in_patches_list[0]

    # Insert additional information about added, modified, and deleted files if there is enough space
    max_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD
    curr_token = total_tokens_new  # == token_handler.count_tokens(final_diff)+token_handler.prompt_tokens
    final_diff = "\n".join(patches_compressed)
    delta_tokens = 10  # minimal spare budget required before any file list is built
    added_list_str = modified_list_str = deleted_list_str = ""
    unprocessed_files = []  # collected below but not read again inside this function

    # generate the added, modified, and deleted files lists
    if (max_tokens - curr_token) > delta_tokens:
        for filename, file_values in file_dict.items():
            if filename in files_in_patch:
                continue
            if file_values['edit_type'] == EDIT_TYPE.ADDED:
                unprocessed_files.append(filename)
                if not added_list_str:
                    added_list_str = ADDED_FILES_ + f"\n{filename}"
                else:
                    added_list_str = added_list_str + f"\n{filename}"
            elif file_values['edit_type'] in [EDIT_TYPE.MODIFIED, EDIT_TYPE.RENAMED]:
                unprocessed_files.append(filename)
                if not modified_list_str:
                    modified_list_str = MORE_MODIFIED_FILES_ + f"\n{filename}"
                else:
                    modified_list_str = modified_list_str + f"\n{filename}"
            elif file_values['edit_type'] == EDIT_TYPE.DELETED:
                # unprocessed_files.append(filename) # not needed here, because the file was deleted, so no need to process it
                if not deleted_list_str:
                    deleted_list_str = DELETED_FILES_ + f"\n{filename}"
                else:
                    deleted_list_str = deleted_list_str + f"\n{filename}"

    # prune the added, modified, and deleted files lists, and add them to the final diff
    # (each list is clipped to the budget left after the previous one; "+ 2" accounts for the "\n\n" separator)
    added_list_str = clip_tokens(added_list_str, max_tokens - curr_token)
    if added_list_str:
        final_diff = final_diff + "\n\n" + added_list_str
        curr_token += token_handler.count_tokens(added_list_str) + 2
    modified_list_str = clip_tokens(modified_list_str, max_tokens - curr_token)
    if modified_list_str:
        final_diff = final_diff + "\n\n" + modified_list_str
        curr_token += token_handler.count_tokens(modified_list_str) + 2
    deleted_list_str = clip_tokens(deleted_list_str, max_tokens - curr_token)
    if deleted_list_str:
        final_diff = final_diff + "\n\n" + deleted_list_str

    get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
                       f"deleted_list_str: {deleted_list_str}")
    if not return_remaining_files:
        return final_diff
    else:
        return final_diff, remaining_files_list
def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler, model: str,
                                add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False):
    """
    Generate the compressed diff of a PR in large-PR mode (possibly several patch groups).

    Args:
        git_provider: Provider used to fetch the diff files and the repository languages.
        token_handler: Token counter forwarded to pr_generate_compressed_diff.
        model: Model name forwarded to pr_generate_compressed_diff.
        add_line_numbers_to_hunks: Whether hunks are converted to the line-numbered format.
        disable_extra_lines: Accepted for signature compatibility; not used by this function.

    Returns:
        The 6-tuple produced by pr_generate_compressed_diff: (patches_compressed_list, total_tokens_list,
        deleted_files_list, remaining_files_list, file_dict, files_in_patches_list).

    Raises:
        RateLimitExceededException: Re-raised after logging when fetching the diff files hits the rate limit.
    """
    try:
        all_diff_files = git_provider.get_diff_files()
    except RateLimitExceededException as e:
        get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
        raise

    kept_diff_files = filter_ignored(all_diff_files)
    if kept_diff_files != all_diff_files:
        # Best-effort report of what the ignore filter removed.
        try:
            get_logger().info(f"Filtered out {len(all_diff_files) - len(kept_diff_files)} files")
            kept_names = {entry.filename for entry in kept_diff_files}
            all_names = {entry.filename for entry in all_diff_files}
            get_logger().info(f"Filtered out files: {all_names - kept_names}")
        except Exception:
            pass

    # get pr languages
    languages = sort_files_by_main_languages(git_provider.get_languages(), kept_diff_files)
    if languages:
        try:
            get_logger().info(f"PR main language: {languages[0]['language']}")
        except Exception:
            pass

    compressed = pr_generate_compressed_diff(languages, token_handler, model, add_line_numbers_to_hunks,
                                             large_pr_handling=True)
    (patches_compressed_list, total_tokens_list, deleted_files_list,
     remaining_files_list, file_dict, files_in_patches_list) = compressed
    return (patches_compressed_list, total_tokens_list, deleted_files_list,
            remaining_files_list, file_dict, files_in_patches_list)
def pr_generate_extended_diff(pr_languages: list,
                              token_handler: TokenHandler,
                              add_line_numbers_to_hunks: bool,
                              patch_extra_lines_before: int = 0,
                              patch_extra_lines_after: int = 0) -> Tuple[list, int, list]:
    """
    Build the per-file extended patches (patch plus extra context lines) and count their tokens.

    Args:
        pr_languages: List of {"language": ..., "files": [...]} groups; files are processed in order.
        token_handler: Provides prompt_tokens (starting total) and count_tokens.
        add_line_numbers_to_hunks: When True, each patch is converted with
            convert_to_hunks_with_lines_numbers; otherwise a "## File: ..." header is prepended.
        patch_extra_lines_before: Context lines to add before each hunk.
        patch_extra_lines_after: Context lines to add after each hunk.

    Returns:
        (patches_extended, total_tokens, patches_extended_tokens). As a side effect, `tokens` is set on
        every file whose patch was generated.
    """
    running_total = token_handler.prompt_tokens  # initial tokens
    extended_patches = []
    extended_patch_tokens = []

    all_files = (pr_file for group in pr_languages for pr_file in group['files'])
    for pr_file in all_files:
        base_content = pr_file.base_file
        raw_patch = pr_file.patch
        if not raw_patch:
            continue

        # extend each patch with extra lines of context
        with_context = extend_patch(base_content, raw_patch,
                                    patch_extra_lines_before, patch_extra_lines_after, pr_file.filename)
        if not with_context:
            get_logger().warning(f"Failed to extend patch for file: {pr_file.filename}")
            continue

        if add_line_numbers_to_hunks:
            rendered = convert_to_hunks_with_lines_numbers(with_context, pr_file)
        else:
            rendered = f"\n\n## File: '{pr_file.filename.strip()}'\n{with_context.rstrip()}\n"

        # add AI-summary metadata to the patch
        if pr_file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False):
            rendered = add_ai_summary_top_patch(pr_file, rendered)

        token_count = token_handler.count_tokens(rendered)
        pr_file.tokens = token_count
        running_total += token_count
        extended_patch_tokens.append(token_count)
        extended_patches.append(rendered)

    return extended_patches, running_total, extended_patch_tokens
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
                                convert_hunks_to_line_numbers: bool,
                                large_pr_handling: bool) -> Tuple[list, list, list, list, dict, list]:
    """
    Build compressed patches (delete-only hunks removed) and pack them into one or more token-bounded groups.

    Args:
        top_langs: List of {"language": ..., "files": [...]} groups.
        token_handler: Provides count_tokens for the per-file patches.
        model: Model name used to look up the maximum token count.
        convert_hunks_to_line_numbers: When True, patches are converted to the line-numbered hunk format.
        large_pr_handling: When True, additional patch groups are generated for the files that did not
            fit into the first one, bounded by the pr_description.max_ai_calls setting.

    Returns:
        (patches_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict,
        files_in_patches_list)
    """
    deleted_files_list = []

    # sort each one of the languages in top_langs by the number of tokens in the diff
    ordered_files = []
    for group in top_langs:
        ordered_files.extend(sorted(group['files'], key=lambda x: x.tokens, reverse=True))

    # generate patches for each file, and count tokens
    file_dict = {}
    for pr_file in ordered_files:
        base_content = pr_file.base_file
        head_content = pr_file.head_file
        current_patch = pr_file.patch
        if not current_patch:
            continue

        # removing delete-only hunks
        current_patch = handle_patch_deletions(current_patch, base_content,
                                               head_content, pr_file.filename, pr_file.edit_type)
        if current_patch is None:
            # the file ends up in the deleted list (once) and gets no patch entry
            if pr_file.filename not in deleted_files_list:
                deleted_files_list.append(pr_file.filename)
            continue

        if convert_hunks_to_line_numbers:
            current_patch = convert_to_hunks_with_lines_numbers(current_patch, pr_file)

        # AI-summary metadata is intentionally not added here, since this is the compressed diff
        file_dict[pr_file.filename] = {
            'patch': current_patch,
            'tokens': token_handler.count_tokens(current_patch),
            'edit_type': pr_file.edit_type,
        }

    max_tokens_model = get_max_tokens(model)

    # first iteration
    remaining_files_list = [pr_file.filename for pr_file in ordered_files]
    total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(
        convert_hunks_to_line_numbers, file_dict, max_tokens_model, remaining_files_list, token_handler)
    patches_list = [patches]
    total_tokens_list = [total_tokens]
    files_in_patches_list = [files_in_patch_list]

    # additional iterations (if needed)
    if large_pr_handling:
        NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1  # one more call is to summarize
        for _ in range(NUMBER_OF_ALLOWED_ITERATIONS - 1):
            if not remaining_files_list:
                break
            total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(
                convert_hunks_to_line_numbers, file_dict, max_tokens_model, remaining_files_list, token_handler)
            if patches:
                patches_list.append(patches)
                total_tokens_list.append(total_tokens)
                files_in_patches_list.append(files_in_patch_list)

    return patches_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_model,remaining_files_list_prev, token_handler):
    """
    Pack as many pending file patches as fit into one token-bounded patch group.

    Args:
        convert_hunks_to_line_numbers: When False, a "## File: ..." header is prepended to each patch.
        file_dict: Mapping filename -> {'patch', 'tokens', 'edit_type'}; iterated in insertion order.
        max_tokens_model: Maximum token count of the model.
        remaining_files_list_prev: Filenames still waiting to be placed; other entries are ignored.
        token_handler: Provides prompt_tokens (starting total) and count_tokens.

    Returns:
        (total_tokens, patches, remaining_files_list_new, files_in_patch_list)
    """
    running_total = token_handler.prompt_tokens  # initial tokens
    group_patches = []
    still_remaining = []
    placed_files = []

    for filename, entry in file_dict.items():
        if filename not in remaining_files_list_prev:
            continue

        raw_patch = entry['patch']
        patch_tokens = entry['tokens']
        edit_type = entry['edit_type']  # looked up to keep the original key access; otherwise unused

        # Hard Stop, no more tokens
        if running_total > max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
            get_logger().warning(f"File was fully skipped, no more tokens: {filename}.")
            continue

        # If the patch is too large, just show the file name
        if running_total + patch_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
            # Current logic is to skip the patch if it's too large
            # TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
            # until we meet the requirements
            if get_settings().config.verbosity_level >= 2:
                get_logger().warning(f"Patch too large, skipping it: '{filename}'")
            still_remaining.append(filename)
            continue

        if not raw_patch:
            continue

        if convert_hunks_to_line_numbers:
            rendered = "\n\n" + raw_patch.strip()
        else:
            rendered = f"\n\n## File: '{filename.strip()}'\n\n{raw_patch.strip()}\n"
        group_patches.append(rendered)
        running_total += token_handler.count_tokens(rendered)
        placed_files.append(filename)
        if get_settings().config.verbosity_level >= 2:
            get_logger().info(f"Tokens: {running_total}, last filename: {filename}")

    return running_total, group_patches, still_remaining, placed_files
async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
    """
    Run `f(model)` with the configured model, falling back to the configured fallback models on failure.

    Args:
        f: Async callable taking the model name; its result is returned on the first success.
        model_type: Selects the primary model (weak vs. regular) via _get_all_models.

    Returns:
        Whatever `f` returns for the first model that succeeds.

    Raises:
        Exception: If every (model, deployment) pair failed; the last failure is chained as the cause.
    """
    all_models = _get_all_models(model_type)
    all_deployments = _get_all_deployments(all_models)
    # try each (model, deployment_id) pair until one is successful, otherwise raise exception
    for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
        try:
            get_logger().debug(
                f"Generating prediction with {model}"
                f"{(' from deployment ' + deployment_id) if deployment_id else ''}"
            )
            get_settings().set("openai.deployment_id", deployment_id)
            return await f(model)
        except Exception as e:
            # Only `Exception` is caught: a bare `except:` would also swallow asyncio.CancelledError and
            # KeyboardInterrupt, which must propagate so the coroutine can be cancelled or interrupted.
            get_logger().warning(
                f"Failed to generate prediction with {model}",
                artifact={"error": str(e), "traceback": traceback.format_exc()}
            )
            if i == len(all_models) - 1:  # If it's the last iteration
                # chain the last failure so the root cause is not lost
                raise Exception(f"Failed to generate prediction with any model of {all_models}") from e
def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
    """
    Return the primary model followed by the configured fallback models.

    Args:
        model_type: ModelType.WEAK selects the weak model; anything else selects config.model.

    Returns:
        A list whose first element is the primary model, followed by the fallback models. The
        fallback_models setting may be a list or a comma-separated string; a missing value and blank
        entries (e.g. from an empty string or a trailing comma) are ignored.
    """
    if model_type == ModelType.WEAK:
        model = get_weak_model()
    else:
        model = get_settings().config.model

    fallback_models = get_settings().config.fallback_models
    if fallback_models is None:
        fallback_models = []
    elif not isinstance(fallback_models, list):
        fallback_models = [m.strip() for m in fallback_models.split(",")]
    # "".split(",") yields [""]; a blank model name would only produce a guaranteed-failing attempt
    fallback_models = [m for m in fallback_models if m]

    return [model] + fallback_models
def _get_all_deployments(all_models: List[str]) -> List[str]:
    """
    Return the deployment id to use for each model in `all_models`.

    Args:
        all_models: Primary model followed by the fallback models.

    Returns:
        When fallback deployments are configured: the primary deployment id followed by the fallback
        deployments. Otherwise: the primary deployment id repeated once per model.

    Raises:
        ValueError: If fallback deployments are configured but there are fewer deployments than models.
    """
    settings = get_settings()
    primary_deployment = settings.get("openai.deployment_id", None)
    fallbacks = settings.get("openai.fallback_deployments", [])

    # a non-empty, non-list value is treated as a comma-separated string
    if fallbacks and not isinstance(fallbacks, list):
        fallbacks = [d.strip() for d in fallbacks.split(",")]

    if not fallbacks:
        return [primary_deployment] * len(all_models)

    deployments = [primary_deployment] + fallbacks
    if len(deployments) < len(all_models):
        raise ValueError(f"The number of deployments ({len(deployments)}) "
                         f"is less than the number of models ({len(all_models)})")
    return deployments
def get_pr_multi_diffs(git_provider: GitProvider,
                       token_handler: TokenHandler,
                       model: str,
                       max_calls: int = 5) -> List[str]:
    """
    Retrieves the diff files from a Git provider, sorts them by main language, and generates patches for each file.
    The patches are split into multiple groups based on the maximum number of tokens allowed for the given model.

    Args:
        git_provider (GitProvider): An object that provides access to Git provider APIs.
        token_handler (TokenHandler): An object that handles tokens in the context of a pull request.
        model (str): The name of the model.
        max_calls (int, optional): The maximum number of diff groups (i.e. model calls) to produce; files
            that would require a further group are dropped. Defaults to 5.

    Returns:
        List[str]: A list of final diff strings, split into multiple groups based on the maximum number of tokens allowed for the given model.

    Raises:
        RateLimitExceededException: If the rate limit for the Git provider API is exceeded.
    """
    try:
        diff_files = git_provider.get_diff_files()
    except RateLimitExceededException as e:
        get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
        raise

    diff_files = filter_ignored(diff_files)

    # Sort files by main language
    pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)

    # Get the maximum number of extra lines before and after the patch
    PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
    PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
    PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(PATCH_EXTRA_LINES_BEFORE, "before")
    PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(PATCH_EXTRA_LINES_AFTER, "after")

    # try first a single run with standard diff string, with patch extension, and no deletions
    patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
        pr_languages, token_handler, add_line_numbers_to_hunks=True,
        patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE,
        patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)

    # if we are under the limit, return the full diff
    if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
        return ["\n".join(patches_extended)] if patches_extended else []

    # Sort files within each language group by tokens in descending order.
    # This must run after pr_generate_extended_diff, which is what fills in `file.tokens`; sorting
    # earlier would compare the not-yet-computed default values and leave the order unchanged.
    sorted_files = []
    for lang in pr_languages:
        sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))

    patches = []
    final_diff_list = []
    total_tokens = token_handler.prompt_tokens
    call_number = 1
    for file in sorted_files:
        if call_number > max_calls:
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"Reached max calls ({max_calls})")
            break

        original_file_content_str = file.base_file
        new_file_content_str = file.head_file
        patch = file.patch
        if not patch:
            continue

        # Remove delete-only hunks
        patch = handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file.filename, file.edit_type)
        if patch is None:
            continue

        patch = convert_to_hunks_with_lines_numbers(patch, file)
        # add AI-summary metadata to the patch
        if file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False):
            patch = add_ai_summary_top_patch(file, patch)
        new_patch_tokens = token_handler.count_tokens(patch)

        # A single patch that does not fit even into an empty prompt is skipped or clipped, per policy
        if patch and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
                model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
            if get_settings().config.get('large_patch_policy', 'skip') == 'skip':
                get_logger().warning(f"Patch too large, skipping: {file.filename}")
                continue
            elif get_settings().config.get('large_patch_policy') == 'clip':
                delta_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens
                patch_clipped = clip_tokens(patch, delta_tokens, delete_last_line=True, num_input_tokens=new_patch_tokens)
                new_patch_tokens = token_handler.count_tokens(patch_clipped)
                if patch_clipped and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
                        model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
                    get_logger().warning(f"Patch too large, skipping: {file.filename}")
                    continue
                else:
                    get_logger().info(f"Clipped large patch for file: {file.filename}")
                    patch = patch_clipped
            else:
                get_logger().warning(f"Patch too large, skipping: {file.filename}")
                continue

        # The current group is full: close it and start a new one (which counts as another call)
        if patch and (total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
            final_diff = "\n".join(patches)
            final_diff_list.append(final_diff)
            patches = []
            total_tokens = token_handler.prompt_tokens
            call_number += 1
            if call_number > max_calls:  # avoid creating new patches
                if get_settings().config.verbosity_level >= 2:
                    get_logger().info(f"Reached max calls ({max_calls})")
                break
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"Call number: {call_number}")

        if patch:
            patches.append(patch)
            total_tokens += new_patch_tokens
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"Tokens: {total_tokens}, last filename: {file.filename}")

    # Add the last chunk
    if patches:
        final_diff = "\n".join(patches)
        final_diff_list.append(final_diff)

    return final_diff_list
def add_ai_metadata_to_diff_files(git_provider, pr_description_files):
    """
    Attach the matching PR-description entry to each diff file (stored in `ai_file_summary`).

    Entries are matched to diff files by stripped file name ('full_file_name' vs. `filename`).
    Any exception is logged and swallowed; the function always returns None.
    """
    try:
        if not pr_description_files:
            get_logger().warning(f"PR description files are empty.")
            return

        summaries_by_name = {}
        for entry in pr_description_files:
            summaries_by_name[entry['full_file_name'].strip()] = entry

        matched = 0
        for diff_file in git_provider.get_diff_files():
            summary = summaries_by_name.get(diff_file.filename.strip())
            if summary is None:
                continue
            diff_file.ai_file_summary = summary
            matched += 1

        if matched == 0:
            get_logger().error(f"Failed to find any matching files between PR description and diff files.",
                               artifact={"pr_description_files": pr_description_files})
    except Exception as e:
        get_logger().error(f"Failed to add AI metadata to diff files: {e}",
                           artifact={"traceback": traceback.format_exc()})
def add_ai_summary_top_patch(file, full_extended_patch):
    """
    Insert the file's AI-generated summary right below the first '## File: ...' header of a patch.

    Args:
        file: Object whose `ai_file_summary['long_summary']` holds the summary text.
        full_extended_patch: The patch text.

    Returns:
        The patch with the summary inserted, or the unchanged patch when it has no file header or when
        anything goes wrong (the error is logged).
    """
    try:
        patch_lines = full_extended_patch.split("\n")
        header_idx = next(
            (idx for idx, text in enumerate(patch_lines) if text.startswith(("## File:", "## file:"))),
            None,
        )
        # if no '## File: ...' was found
        if header_idx is None:
            return full_extended_patch

        summary_block = f"### AI-generated changes summary:\n{file.ai_file_summary['long_summary']}"
        return "\n".join(patch_lines[:header_idx + 1] + [summary_block] + patch_lines[header_idx + 1:])
    except Exception as e:
        get_logger().error(f"Failed to add AI summary to the top of the patch: {e}",
                           artifact={"traceback": traceback.format_exc()})
        return full_extended_patch

View File

@ -0,0 +1,89 @@
from threading import Lock
from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model, get_encoding
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.log import get_logger
class TokenEncoder:
    """
    Process-wide cache of the tiktoken encoder that matches the configured model.

    The encoder is rebuilt whenever `config.model` differs from the model it was built for.
    """
    _encoder_instance = None
    _model = None
    _lock = Lock()  # serializes (re)creation of the cached encoder

    @classmethod
    def get_token_encoder(cls):
        """
        Return the cached encoder, creating or replacing it when the configured model changed.

        Models whose name contains "gpt" use tiktoken's encoding_for_model; all others use "cl100k_base".
        """
        model = get_settings().config.model
        if cls._encoder_instance is None or model != cls._model:  # Check without acquiring the lock for performance
            with cls._lock:  # Lock acquisition to ensure thread safety
                if cls._encoder_instance is None or model != cls._model:
                    # Build first, publish afterwards: if building raises, the cache keeps its previous
                    # (consistent) state instead of claiming the new model with a stale or missing encoder.
                    encoder = encoding_for_model(model) if "gpt" in model else get_encoding(
                        "cl100k_base")
                    cls._encoder_instance = encoder
                    cls._model = model
        return cls._encoder_instance
class TokenHandler:
    """
    A class for handling tokens in the context of a pull request.

    Attributes:
    - encoder: The token encoder returned by TokenEncoder.get_token_encoder(). Used to encode strings and count the
      number of tokens in them.
    - prompt_tokens: The number of tokens in the rendered system and user prompts, as calculated by the
      _get_system_user_tokens method; 0 when no pull request was given.
    """

    def __init__(self, pr=None, vars: dict = {}, system="", user=""):
        """
        Initializes the TokenHandler object.

        Args:
        - pr: The pull request object. When None, no prompt is rendered and prompt_tokens is 0.
        - vars: A dictionary of variables (only passed through to the template rendering here).
        - system: The system string.
        - user: The user string.
        """
        self.encoder = TokenEncoder.get_token_encoder()
        if pr is not None:
            self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
        else:
            # Always define the attribute: the diff generators read `prompt_tokens` unconditionally.
            self.prompt_tokens = 0

    def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
        """
        Calculates the number of tokens in the system and user strings.

        Args:
        - pr: The pull request object (not used by the computation).
        - encoder: The token encoder used to encode the rendered prompts.
        - vars: A dictionary of variables used to render the prompt templates.
        - system: The system string (a Jinja2 template).
        - user: The user string (a Jinja2 template).

        Returns:
        The sum of the number of tokens in the system and user strings, or 0 if rendering or encoding
        fails (the error is logged).
        """
        try:
            # StrictUndefined makes a template variable missing from `vars` raise instead of rendering empty
            environment = Environment(undefined=StrictUndefined)
            system_prompt = environment.from_string(system).render(vars)
            user_prompt = environment.from_string(user).render(vars)
            system_prompt_tokens = len(encoder.encode(system_prompt))
            user_prompt_tokens = len(encoder.encode(user_prompt))
            return system_prompt_tokens + user_prompt_tokens
        except Exception as e:
            get_logger().error(f"Error in _get_system_user_tokens: {e}")
            return 0

    def count_tokens(self, patch: str) -> int:
        """
        Counts the number of tokens in a given patch string.

        Args:
        - patch: The patch string.

        Returns:
        The number of tokens in the patch string.
        """
        return len(self.encoder.encode(patch, disallowed_special=()))

View File

@ -0,0 +1,26 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional
class EDIT_TYPE(Enum):
    """Kind of change a file underwent in a pull request."""
    ADDED = 1
    DELETED = 2
    MODIFIED = 3
    RENAMED = 4
    UNKNOWN = 5  # default used by FilePatchInfo when no edit type is given
@dataclass
class FilePatchInfo:
    """Per-file diff information of a pull request, as consumed by the diff-processing functions."""
    base_file: str  # file content the patch is extended against (the "original" content)
    head_file: str  # new file content, used when removing delete-only hunks
    patch: str  # patch text of the file; files with an empty patch are skipped by the diff generators
    filename: str
    tokens: int = -1  # token count of the generated patch; -1 until pr_generate_extended_diff sets it
    edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN
    old_filename: Optional[str] = None  # presumably the previous path of a renamed file; verify against providers
    num_plus_lines: int = -1
    num_minus_lines: int = -1
    language: Optional[str] = None
    # PR-description entry attached by add_ai_metadata_to_diff_files; read as a mapping
    # (its 'long_summary' item is inserted into patches).
    ai_file_summary: Optional[dict] = None

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,96 @@
import argparse
import asyncio
import os
from utils.pr_agent.agent.pr_agent import PRAgent, commands
from utils.pr_agent.algo.utils import get_version
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.log import get_logger, setup_logger
# Configure logging once at import time; the level comes from the LOG_LEVEL environment variable
# and defaults to "INFO".
log_level = os.environ.get("LOG_LEVEL", "INFO")
setup_logger(log_level)
def set_parser():
    """
    Build the command-line parser of the PR agent.

    Returns:
        argparse.ArgumentParser: Parser accepting --version, --pr_url, --issue_url, a positional
        `command` (restricted to the known commands) and the remaining arguments in `rest`.
    """
    # NOTE: the usage text must name the options exactly as they are registered below (--pr_url, and
    # double-dash "--<section>.<key>=<value>" overrides), otherwise copy-pasted examples are rejected.
    parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage=
"""\
Usage: cli.py --pr_url=<URL on supported git hosting service> <command> [<args>].
For example:
- cli.py --pr_url=... review
- cli.py --pr_url=... describe
- cli.py --pr_url=... improve
- cli.py --pr_url=... ask "write me a poem about this PR"
- cli.py --pr_url=... reflect
- cli.py --issue_url=... similar_issue
Supported commands:
- review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.
- ask / ask_question [question] - Ask a question about the PR.
- describe / describe_pr - Modify the PR title and description based on the PR's contents.
- improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
Extended mode ('improve --extended') employs several calls, and provides a more thorough feedback
- reflect - Ask the PR author questions about the PR.
- update_changelog - Update the changelog based on the PR's contents.
- add_docs
- generate_labels
Configuration:
To edit any configuration parameter from 'configuration.toml', just add --config_path=<value>.
For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
""")
    parser.add_argument('--version', action='version', version=f'pr-agent {get_version()}')
    parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', default=None)
    parser.add_argument('--issue_url', type=str, help='The URL of the Issue to review', default=None)
    parser.add_argument('command', type=str, help='The command to run', choices=commands, default='review')
    parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
    return parser
def run_command(pr_url, command):
    """
    Run a single agent command (e.g. '/review') against the given PR URL.

    The command string is split on whitespace and parsed with the regular CLI parser, then handed to run().
    """
    # Preparing the command
    cli_tokens = f"--pr_url={pr_url} {command.lstrip('/')}".split()
    parsed_args = set_parser().parse_args(cli_tokens)

    # Run the command. Feedback will appear in GitHub PR comments
    run(args=parsed_args)
def run(inargs=None, args=None):
    """
    Parse the command line (unless `args` is given) and execute the requested agent command.

    Args:
        inargs: Optional list of argument strings to parse instead of sys.argv.
        args: Optional, already parsed arguments; when provided, `inargs` is ignored.

    Prints the parser help when neither a PR URL nor an issue URL is given, or when the request
    handler returns a falsy result.
    """
    parser = set_parser()
    if not args:
        args = parser.parse_args(inargs)
    if not args.pr_url and not args.issue_url:
        parser.print_help()
        return

    command = args.command.lower()
    get_settings().set("CONFIG.CLI_MODE", True)

    async def inner():
        if args.issue_url:
            result = await asyncio.create_task(PRAgent().handle_request(args.issue_url, [command] + args.rest))
        else:
            result = await asyncio.create_task(PRAgent().handle_request(args.pr_url, [command] + args.rest))

        if get_settings().litellm.get("enable_callbacks", False):
            # There may be additional events on the event queue from the run above. If there are give them time to complete.
            get_logger().debug("Waiting for event queue to complete")
            pending_tasks = [task for task in asyncio.all_tasks() if task is not asyncio.current_task()]
            # asyncio.wait() raises ValueError on an empty collection, so only wait when something is pending
            if pending_tasks:
                await asyncio.wait(pending_tasks)

        return result

    result = asyncio.run(inner())
    if not result:
        parser.print_help()


if __name__ == '__main__':
    run()

View File

@ -0,0 +1,23 @@
from utils.pr_agent import cli
from utils.pr_agent.config_loader import get_settings
def main():
    """Example entry point: configure the provider and keys in code, then run one command on a PR."""
    # Fill in the following values
    provider = "github"  # GitHub provider
    user_token = "..."  # GitHub user token
    openai_key = "..."  # OpenAI key
    pr_url = "..."  # PR URL, for example 'https://github.com/Codium-ai/pr-agent/pull/809'
    command = "/review"  # Command to run (e.g. '/review', '/describe', '/ask="What is the purpose of this PR?"')

    # Setting the configurations
    overrides = (
        ("CONFIG.git_provider", provider),
        ("openai.key", openai_key),
        ("github.user_token", user_token),
    )
    for setting_key, setting_value in overrides:
        get_settings().set(setting_key, setting_value)

    # Run the command. Feedback will appear in GitHub PR comments
    cli.run_command(pr_url, command)


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,81 @@
from os.path import abspath, dirname, join
from pathlib import Path
from typing import Optional
from dynaconf import Dynaconf
from starlette_context import context
# Name of the pyproject.toml tool section ("tool.pr-agent") loaded at the bottom of this module.
PR_AGENT_TOML_KEY = 'pr-agent'

current_dir = dirname(abspath(__file__))
# Process-wide default settings, loaded from the TOML files below (paths are relative to this module's
# directory). get_settings() falls back to this object when no per-request settings are available.
global_settings = Dynaconf(
    envvar_prefix=False,
    merge_enabled=True,
    settings_files=[join(current_dir, f) for f in [
        "settings/configuration.toml",
        "settings/ignore.toml",
        "settings/language_extensions.toml",
        "settings/pr_reviewer_prompts.toml",
        "settings/pr_questions_prompts.toml",
        "settings/pr_line_questions_prompts.toml",
        "settings/pr_description_prompts.toml",
        "settings/pr_code_suggestions_prompts.toml",
        "settings/pr_code_suggestions_reflect_prompts.toml",
        "settings/pr_sort_code_suggestions_prompts.toml",
        "settings/pr_information_from_user_prompts.toml",
        "settings/pr_update_changelog_prompts.toml",
        "settings/pr_custom_labels.toml",
        "settings/pr_add_docs.toml",
        "settings/custom_labels.toml",
        "settings/pr_help_prompts.toml",
        "settings/.secrets.toml",
        "settings_prod/.secrets.toml",
    ]]
)
def get_settings():
    """
    Return the settings object that is currently in effect.

    The per-request settings stored under "settings" in the starlette_context `context` take precedence.
    If they cannot be read for any reason, the module-level `global_settings` is returned instead.

    Returns:
        Dynaconf: The current settings object, either from the context or the global default.
    """
    try:
        request_settings = context["settings"]
    except Exception:
        return global_settings
    return request_settings
# Add local configuration from pyproject.toml of the project being reviewed
def _find_repository_root() -> Optional[Path]:
    """
    Locate the project root: the closest directory, starting at the resolved current working directory
    and walking up through its parents, that contains a `.git` directory.

    Returns:
        The root directory, or None when no directory up to the filesystem root contains `.git`.
    """
    start = Path.cwd().resolve()
    for candidate in (start, *start.parents):
        if (candidate / ".git").is_dir():
            return candidate
    return None
def _find_pyproject() -> Optional[Path]:
    """
    Return the path of `pyproject.toml` in the repository root.

    Returns:
        The file path, or None when no repository root was found or the file does not exist there.
    """
    repo_root = _find_repository_root()
    if not repo_root:
        return None
    candidate = repo_root / "pyproject.toml"
    if candidate.is_file():
        return candidate
    return None
# At import time, load the pyproject.toml found in the repository root (if any) into the current
# settings, using its "tool.pr-agent" section as the environment.
pyproject_path = _find_pyproject()
if pyproject_path is not None:
    get_settings().load_file(pyproject_path, env=f'tool.{PR_AGENT_TOML_KEY}')

View File

@ -0,0 +1,64 @@
from starlette_context import context
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider
from utils.pr_agent.git_providers.bitbucket_provider import BitbucketProvider
from utils.pr_agent.git_providers.bitbucket_server_provider import \
BitbucketServerProvider
from utils.pr_agent.git_providers.codecommit_provider import CodeCommitProvider
from utils.pr_agent.git_providers.gerrit_provider import GerritProvider
from utils.pr_agent.git_providers.git_provider import GitProvider
from utils.pr_agent.git_providers.github_provider import GithubProvider
from utils.pr_agent.git_providers.gitlab_provider import GitLabProvider
from utils.pr_agent.git_providers.local_git_provider import LocalGitProvider
# Registry mapping the `config.git_provider` setting value to the provider class implementing it.
_GIT_PROVIDERS = {
    'github': GithubProvider,
    'gitlab': GitLabProvider,
    'bitbucket': BitbucketProvider,
    'bitbucket_server': BitbucketServerProvider,
    'azure': AzureDevopsProvider,
    'codecommit': CodeCommitProvider,
    'local': LocalGitProvider,
    'gerrit': GerritProvider,
}
def get_git_provider():
    """
    Return the provider class selected by the `config.git_provider` setting.

    Raises:
        ValueError: If the setting is missing or names a provider that is not registered.
    """
    try:
        provider_id = get_settings().config.git_provider
    except AttributeError as e:
        raise ValueError("git_provider is a required attribute in the configuration file") from e

    provider_cls = _GIT_PROVIDERS.get(provider_id)
    if provider_cls is None:
        raise ValueError(f"Unknown git provider: {provider_id}")
    return provider_cls
def get_git_provider_with_context(pr_url) -> GitProvider:
    """
    Get a GitProvider instance for the given PR URL. If the GitProvider instance is already in the context, return it.

    Args:
        pr_url: URL of the pull request the provider should be bound to.

    Returns:
        The provider cached in the request context for this URL, or a newly constructed one
        (which is stored in the context when a context is available).

    Raises:
        ValueError: when the configured provider is unknown or the provider cannot be constructed.
    """
    is_context_env = None
    try:
        is_context_env = context.get("settings", None)
    except Exception:
        pass  # we are not in a context environment (CLI)

    if is_context_env:
        # The cache is stored below as {pr_url: provider}, so it must be looked up by the actual
        # URL value. Looking up the literal string "pr_url" would never find the cached provider.
        cached_provider = context.get("git_provider", {}).get(pr_url)
        if cached_provider:
            # possibly check if the git_provider is still valid, or if some reset is needed
            return cached_provider

    try:
        provider_id = get_settings().config.git_provider
        if provider_id not in _GIT_PROVIDERS:
            raise ValueError(f"Unknown git provider: {provider_id}")
        git_provider = _GIT_PROVIDERS[provider_id](pr_url)
        if is_context_env:
            context["git_provider"] = {pr_url: git_provider}
        return git_provider
    except Exception as e:
        raise ValueError(f"Failed to get git provider for {pr_url}") from e

View File

@ -0,0 +1,620 @@
import os
from typing import Optional, Tuple
from urllib.parse import urlparse
from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from ..algo.file_filter import filter_ignored
from ..algo.language_handler import is_valid_file
from ..algo.utils import (PRDescriptionHeader, find_line_number_of_relevant_line_in_file,
load_large_diff)
from ..config_loader import get_settings
from ..log import get_logger
from .git_provider import GitProvider
# Set to False further down when the optional Azure DevOps SDK imports fail.
AZURE_DEVOPS_AVAILABLE = True
# Scope passed to DefaultAzureCredential.get_token() when no PAT is configured.
ADO_APP_CLIENT_DEFAULT_ID = "499b84ac-1321-427f-aa17-267ca6975798/.default"
# Maximum PR description length enforced by publish_description before updating the PR.
MAX_PR_DESCRIPTION_AZURE_LENGTH = 4000-1
try:
# noinspection PyUnresolvedReferences
# noinspection PyUnresolvedReferences
from azure.devops.connection import Connection
# noinspection PyUnresolvedReferences
from azure.devops.v7_1.git.models import (Comment, CommentThread,
GitPullRequest,
GitPullRequestIterationChanges,
GitVersionDescriptor)
# noinspection PyUnresolvedReferences
from azure.identity import DefaultAzureCredential
from msrest.authentication import BasicAuthentication
except ImportError:
AZURE_DEVOPS_AVAILABLE = False
class AzureDevopsProvider(GitProvider):
def __init__(
        self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
):
    """
    Create the Azure DevOps git client and, when `pr_url` is given, bind to that pull request.

    Raises:
        ImportError: when the Azure DevOps SDK could not be imported.
    """
    if not AZURE_DEVOPS_AVAILABLE:
        raise ImportError(
            "Azure DevOps provider is not available. Please install the required dependencies."
        )

    self.azure_devops_client = self._get_azure_devops_client()

    # Nothing is known about the pull request until set_pr() runs.
    self.diff_files = None
    self.workspace_slug = self.repo_slug = self.repo = None
    self.pr_num = self.pr = None
    self.temp_comments = []
    self.incremental = incremental

    if pr_url:
        self.set_pr(pr_url)
def publish_code_suggestions(self, code_suggestions: list) -> bool:
    """
    Publishes code suggestions as comments on the PR.

    Each suggestion is a dict with 'body', 'relevant_file', 'relevant_lines_start' and
    'relevant_lines_end'. Suggestions whose start line is missing or -1, or whose end line
    precedes the start line, are skipped with a warning.

    Returns:
        False when no suggestion passed validation; True otherwise. Failures while creating an
        individual thread are logged as warnings and do not change the result.
    """
    post_parameters_list = []
    for suggestion in code_suggestions:
        body = suggestion['body']
        relevant_file = suggestion['relevant_file']
        relevant_lines_start = suggestion['relevant_lines_start']
        relevant_lines_end = suggestion['relevant_lines_end']

        if not relevant_lines_start or relevant_lines_start == -1:
            get_logger().warning(
                f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
            continue

        if relevant_lines_end < relevant_lines_start:
            get_logger().warning(f"Failed to publish code suggestion, "
                                 f"relevant_lines_end is {relevant_lines_end} and "
                                 f"relevant_lines_start is {relevant_lines_start}")
            continue

        if relevant_lines_end > relevant_lines_start:
            post_parameters = {
                "body": body,
                "path": relevant_file,
                "line": relevant_lines_end,
                "start_line": relevant_lines_start,
                "start_side": "RIGHT",
            }
        else:  # API is different for single line comments
            post_parameters = {
                "body": body,
                "path": relevant_file,
                "line": relevant_lines_start,
                "side": "RIGHT",
            }
        post_parameters_list.append(post_parameters)

    if not post_parameters_list:
        return False

    for post_parameters in post_parameters_list:
        try:
            end_line = post_parameters["line"]
            # Single-line suggestions carry no "start_line" key; anchor the thread start on the
            # same line instead of failing with a KeyError and dropping the suggestion.
            start_line = post_parameters.get("start_line", end_line)
            comment = Comment(content=post_parameters["body"], comment_type=1)
            thread = CommentThread(comments=[comment],
                                   thread_context={
                                       "filePath": post_parameters["path"],
                                       "rightFileStart": {
                                           "line": start_line,
                                           "offset": 1,
                                       },
                                       "rightFileEnd": {
                                           "line": end_line,
                                           "offset": 1,
                                       },
                                   })
            self.azure_devops_client.create_thread(
                comment_thread=thread,
                project=self.workspace_slug,
                repository_id=self.repo_slug,
                pull_request_id=self.pr_num
            )
        except Exception as e:
            get_logger().warning(f"Azure failed to publish code suggestion, error: {e}")
    return True
def get_pr_description_full(self) -> str:
return self.pr.description
def edit_comment(self, comment, body: str):
    """
    Replace the content of an existing thread comment.

    `comment` is a mapping with 'thread_id' and 'comment_id'. Failures are logged, not raised.
    """
    try:
        thread_id = comment["thread_id"]
        comment_id = comment["comment_id"]
        replacement = Comment(content=body)
        self.azure_devops_client.update_comment(
            project=self.workspace_slug,
            repository_id=self.repo_slug,
            pull_request_id=self.pr_num,
            thread_id=thread_id,
            comment_id=comment_id,
            comment=replacement,
        )
    except Exception as e:
        get_logger().exception(f"Failed to edit comment, error: {e}")
def remove_comment(self, comment):
try:
self.azure_devops_client.delete_comment(
repository_id=self.repo_slug,
pull_request_id=self.pr_num,
thread_id=comment["thread_id"],
comment_id=comment["comment_id"],
project=self.workspace_slug,
)
except Exception as e:
get_logger().exception(f"Failed to remove comment, error: {e}")
def publish_labels(self, pr_types):
try:
for pr_type in pr_types:
self.azure_devops_client.create_pull_request_label(
label={"name": pr_type},
project=self.workspace_slug,
repository_id=self.repo_slug,
pull_request_id=self.pr_num,
)
except Exception as e:
get_logger().warning(f"Failed to publish labels, error: {e}")
def get_pr_labels(self, update=False):
try:
labels = self.azure_devops_client.get_pull_request_labels(
project=self.workspace_slug,
repository_id=self.repo_slug,
pull_request_id=self.pr_num,
)
return [label.name for label in labels]
except Exception as e:
get_logger().exception(f"Failed to get labels, error: {e}")
return []
def is_supported(self, capability: str) -> bool:
    """Return False for the capabilities this provider opts out of, True for everything else."""
    unsupported = [
        "get_issue_comments",
    ]
    return capability not in unsupported
def set_pr(self, pr_url: str):
self.workspace_slug, self.repo_slug, self.pr_num = self._parse_pr_url(pr_url)
self.pr = self._get_pr()
def get_repo_settings(self):
    """
    Fetch the repository's `.pr_agent.toml` and return the first chunk of its content.

    Returns an empty string on any failure (logged only at verbosity level 2 or higher).
    """
    try:
        content_chunks = self.azure_devops_client.get_item_content(
            repository_id=self.repo_slug,
            project=self.workspace_slug,
            path=".pr_agent.toml",
            download=False,
            include_content_metadata=False,
            include_content=True,
        )
        chunks = list(content_chunks)
        return chunks[0]
    except Exception as e:
        if get_settings().config.verbosity_level >= 2:
            get_logger().error(f"Failed to get repo settings, error: {e}")
        return ""
def get_files(self):
files = []
for i in self.azure_devops_client.get_pull_request_commits(
project=self.workspace_slug,
repository_id=self.repo_slug,
pull_request_id=self.pr_num,
):
changes_obj = self.azure_devops_client.get_changes(
project=self.workspace_slug,
repository_id=self.repo_slug,
commit_id=i.commit_id,
)
for c in changes_obj.changes:
files.append(c["item"]["path"])
return list(set(files))
def get_diff_files(self) -> list[FilePatchInfo]:
    """
    Build the per-file patch information for the pull request and cache it on the instance.

    The changed paths come from the most recent PR iteration. For each valid file, the content
    at the source and target merge commits is fetched and a patch is generated from the two.

    Returns:
        A list of FilePatchInfo; an empty list when anything fails (the error is logged).
    """
    try:
        if self.diff_files:
            return self.diff_files

        base_sha = self.pr.last_merge_target_commit
        head_sha = self.pr.last_merge_source_commit

        pr_iterations = self.azure_devops_client.get_pull_request_iterations(
            repository_id=self.repo_slug,
            pull_request_id=self.pr_num,
            project=self.workspace_slug
        )
        changes = None
        if pr_iterations:
            # the last iteration holds the most recent changes
            changes = self.azure_devops_client.get_pull_request_iteration_changes(
                repository_id=self.repo_slug,
                pull_request_id=self.pr_num,
                iteration_id=pr_iterations[-1].id,
                project=self.workspace_slug
            )

        changed_paths = []
        change_kinds = {}
        if changes:
            for entry in changes.change_entries:
                props = entry.additional_properties
                path = props.get('item', {}).get('path', None)
                if not path:
                    continue
                changed_paths.append(path)
                change_kinds[path] = props.get('changeType', 'Unknown')

        kept_paths = filter_ignored(changed_paths, 'azure')
        if changed_paths != kept_paths:
            try:
                get_logger().info(f"Filtered out [ignore] files for pull request:", extra=
                {"files": changed_paths,  # plain list of path names
                 "filtered_files": kept_paths})
            except Exception:
                pass

        diff_files = []
        invalid_files_names = []
        for file in kept_paths:
            if not is_valid_file(file):
                invalid_files_names.append(file)
                continue

            # content on the PR source side
            head_version = GitVersionDescriptor(
                version=head_sha.commit_id, version_type="commit"
            )
            try:
                head_item = self.azure_devops_client.get_item(
                    repository_id=self.repo_slug,
                    path=file,
                    project=self.workspace_slug,
                    version_descriptor=head_version,
                    download=False,
                    include_content=True,
                )
                new_file_content_str = head_item.content
            except Exception as error:
                get_logger().error(f"Failed to retrieve new file content of {file} at version {head_version}", error=error)
                new_file_content_str = ""

            change_kind = change_kinds[file]
            if change_kind == "add":
                edit_type = EDIT_TYPE.ADDED
            elif change_kind == "delete":
                edit_type = EDIT_TYPE.DELETED
            elif "rename" in change_kind:  # change kind can be `rename` | `edit, rename`
                edit_type = EDIT_TYPE.RENAMED
            else:
                edit_type = EDIT_TYPE.MODIFIED

            # content on the PR target side (skipped for added and renamed files)
            base_version = GitVersionDescriptor(
                version=base_sha.commit_id, version_type="commit"
            )
            if edit_type in (EDIT_TYPE.ADDED, EDIT_TYPE.RENAMED):
                original_file_content_str = ""
            else:
                try:
                    base_item = self.azure_devops_client.get_item(
                        repository_id=self.repo_slug,
                        path=file,
                        project=self.workspace_slug,
                        version_descriptor=base_version,
                        download=False,
                        include_content=True,
                    )
                    original_file_content_str = base_item.content
                except Exception as error:
                    get_logger().error(f"Failed to retrieve original file content of {file} at version {base_version}", error=error)
                    original_file_content_str = ""

            patch = load_large_diff(
                file, new_file_content_str, original_file_content_str, show_warning=False
            ).rstrip()

            # count number of lines added and removed
            patch_lines = patch.splitlines(keepends=True)
            num_plus_lines = sum(1 for line in patch_lines if line.startswith('+'))
            num_minus_lines = sum(1 for line in patch_lines if line.startswith('-'))

            diff_files.append(
                FilePatchInfo(
                    original_file_content_str,
                    new_file_content_str,
                    patch=patch,
                    filename=file,
                    edit_type=edit_type,
                    num_plus_lines=num_plus_lines,
                    num_minus_lines=num_minus_lines,
                )
            )

        get_logger().info(f"Invalid files: {invalid_files_names}")

        self.diff_files = diff_files
        return diff_files
    except Exception as e:
        get_logger().exception(f"Failed to get diff files, error: {e}")
        return []
def publish_comment(self, pr_comment: str, is_temporary: bool = False, thread_context=None):
    """
    Post `pr_comment` as a new thread on the pull request.

    Temporary comments are skipped (returning None) when `config.publish_output_progress` is
    off; otherwise they are remembered in `self.temp_comments` for later removal.

    Returns:
        A dict with the created 'thread_id' and 'comment_id', or None when skipped.
    """
    if is_temporary and not get_settings().config.publish_output_progress:
        get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
        return None

    new_thread = CommentThread(
        comments=[Comment(content=pr_comment)],
        thread_context=thread_context,
        status=5,
    )
    created = self.azure_devops_client.create_thread(
        comment_thread=new_thread,
        project=self.workspace_slug,
        repository_id=self.repo_slug,
        pull_request_id=self.pr_num,
    )
    response = {"thread_id": created.id, "comment_id": created.comments[0].id}
    if is_temporary:
        self.temp_comments.append(response)
    return response
def publish_description(self, pr_title: str, pr_body: str):
    """
    Update the pull request title and description.

    A body longer than MAX_PR_DESCRIPTION_AZURE_LENGTH is shortened in stages: cut at the usage
    guide section, then at the changes walkthrough section, then hard-truncated with a notice.
    Update failures are logged, not raised.
    """
    limit = MAX_PR_DESCRIPTION_AZURE_LENGTH

    if len(pr_body) > limit:
        guide_marker = '<details> <summary><strong>✨ Describe tool usage guide:</strong></summary><hr>'
        cut_at = pr_body.find(guide_marker)
        if cut_at != -1:
            pr_body = pr_body[:cut_at]

    if len(pr_body) > limit:
        walkthrough_marker = PRDescriptionHeader.CHANGES_WALKTHROUGH.value
        cut_at = pr_body.find(walkthrough_marker)
        if cut_at != -1:
            pr_body = pr_body[:cut_at]

    if len(pr_body) > limit:
        truncation_notice = " ... (description truncated due to length limit)"
        pr_body = pr_body[:limit - len(truncation_notice)] + truncation_notice
        get_logger().warning("PR description was truncated due to length limit")

    try:
        pr_update = GitPullRequest()
        pr_update.title = pr_title
        pr_update.description = pr_body
        self.azure_devops_client.update_pull_request(
            project=self.workspace_slug,
            repository_id=self.repo_slug,
            pull_request_id=self.pr_num,
            git_pull_request_to_update=pr_update,
        )
    except Exception as e:
        get_logger().exception(
            f"Could not update pull request {self.pr_num} description: {e}"
        )
def remove_initial_comment(self):
try:
for comment in self.temp_comments:
self.remove_comment(comment)
except Exception as e:
get_logger().exception(f"Failed to remove temp comments, error: {e}")
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
self.publish_inline_comments([self.create_inline_comment(body, relevant_file, relevant_line_in_file)])
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
                          absolute_position: int = None):
    """
    Locate `relevant_line_in_file` in the PR diff and describe an inline comment for it.

    Returns:
        A dict with 'body', 'path', 'position' and 'absolute_position', or an empty dict when
        the line could not be located.
    """
    position, absolute_position = find_line_number_of_relevant_line_in_file(self.get_diff_files(),
                                                                            relevant_file.strip('`'),
                                                                            relevant_line_in_file,
                                                                            absolute_position)
    path = relevant_file.strip()
    if position == -1:
        if get_settings().config.verbosity_level >= 2:
            get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
        return {}
    return dict(body=body, path=path, position=position, absolute_position=absolute_position)
def publish_inline_comments(self, comments: list[dict], disable_fallback: bool = False):
    """
    Publish each inline comment as a thread anchored at its file position.

    Returns:
        True when every comment was published, False when at least one failed.
    """
    all_published = True
    for comment in comments:
        try:
            body = comment["body"]
            # start and end share the same anchor: the comment targets a single position
            anchor_context = {
                "filePath": comment["path"],
                "rightFileStart": {
                    "line": comment["absolute_position"],
                    "offset": comment["position"],
                },
                "rightFileEnd": {
                    "line": comment["absolute_position"],
                    "offset": comment["position"],
                },
            }
            self.publish_comment(body, thread_context=anchor_context)
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(
                    f"Published code suggestion on {self.pr_num} at {comment['path']}"
                )
        except Exception as e:
            if get_settings().config.verbosity_level >= 2:
                get_logger().error(f"Failed to publish code suggestion, error: {e}")
            all_published = False
    return all_published
def get_title(self):
return self.pr.title
def get_languages(self):
languages = []
files = self.azure_devops_client.get_items(
project=self.workspace_slug,
repository_id=self.repo_slug,
recursion_level="Full",
include_content_metadata=True,
include_links=False,
download=False,
)
for f in files:
if f.git_object_type == "blob":
file_name, file_extension = os.path.splitext(f.path)
languages.append(file_extension[1:])
extension_counts = {}
for ext in languages:
if ext != "":
extension_counts[ext] = extension_counts.get(ext, 0) + 1
total_extensions = sum(extension_counts.values())
extension_percentages = {
ext: (count / total_extensions) * 100
for ext, count in extension_counts.items()
}
return extension_percentages
def get_pr_branch(self):
pr_info = self.azure_devops_client.get_pull_request_by_id(
project=self.workspace_slug, pull_request_id=self.pr_num
)
source_branch = pr_info.source_ref_name.split("/")[-1]
return source_branch
def get_user_id(self):
    """Return a constant placeholder id (0); no user lookup is performed."""
    placeholder_id = 0
    return placeholder_id
def get_issue_comments(self):
threads = self.azure_devops_client.get_threads(repository_id=self.repo_slug, pull_request_id=self.pr_num, project=self.workspace_slug)
threads.reverse()
comment_list = []
for thread in threads:
for comment in thread.comments:
if comment.content and comment not in comment_list:
comment.body = comment.content
comment.thread_id = thread.id
comment_list.append(comment)
return comment_list
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
    """No-op for this provider: nothing is posted and True is always returned."""
    acknowledged = True
    return acknowledged
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
    """No-op for this provider: nothing is removed and True is always returned."""
    removed = True
    return removed
@staticmethod
def _parse_pr_url(pr_url: str) -> Tuple[str, str, int]:
parsed_url = urlparse(pr_url)
path_parts = parsed_url.path.strip("/").split("/")
if "pullrequest" not in path_parts:
raise ValueError(
"The provided URL does not appear to be a Azure DevOps PR URL"
)
if len(path_parts) == 6: # "https://dev.azure.com/organization/project/_git/repo/pullrequest/1"
workspace_slug = path_parts[1]
repo_slug = path_parts[3]
pr_number = int(path_parts[5])
elif len(path_parts) == 5: # 'https://organization.visualstudio.com/project/_git/repo/pullrequest/1'
workspace_slug = path_parts[0]
repo_slug = path_parts[2]
pr_number = int(path_parts[4])
else:
raise ValueError("The provided URL does not appear to be a Azure DevOps PR URL")
return workspace_slug, repo_slug, pr_number
@staticmethod
def _get_azure_devops_client():
    """
    Create the Azure DevOps git client from the `azure_devops` settings.

    Authentication uses the configured PAT when present; otherwise a token is requested through
    DefaultAzureCredential for the ADO_APP_CLIENT_DEFAULT_ID scope.

    Raises:
        ValueError: when no organization is configured.
        Exception: whatever the default-credential flow raised, after logging it.
    """
    org = get_settings().azure_devops.get("org", None)
    pat = get_settings().azure_devops.get("pat", None)

    if not org:
        raise ValueError("Azure DevOps organization is required")

    if pat:
        auth_token = pat
    else:
        try:
            # try to use azure default credentials
            # see https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python
            # for usage and env var configuration of user-assigned managed identity, local machine auth etc.
            get_logger().info("No PAT found in settings, trying to use Azure Default Credentials.")
            default_credential = DefaultAzureCredential()
            access_token = default_credential.get_token(ADO_APP_CLIENT_DEFAULT_ID)
            auth_token = access_token.token
        except Exception as e:
            get_logger().error(f"No PAT found in settings, and Azure Default Authentication failed, error: {e}")
            raise

    # The token (PAT or access token) is passed as the password of a basic-auth pair with an
    # empty user name. The original built this object twice; once is enough.
    credentials = BasicAuthentication("", auth_token)
    azure_devops_connection = Connection(base_url=org, creds=credentials)
    azure_devops_client = azure_devops_connection.clients.get_git_client()

    return azure_devops_client
def _get_repo(self):
if self.repo is None:
self.repo = self.azure_devops_client.get_repository(
project=self.workspace_slug, repository_id=self.repo_slug
)
return self.repo
def _get_pr(self):
self.pr = self.azure_devops_client.get_pull_request_by_id(
pull_request_id=self.pr_num, project=self.workspace_slug
)
return self.pr
def get_commit_messages(self):
    """Return an empty string: commit messages are not implemented for this provider yet."""
    not_implemented_yet = ""
    return not_implemented_yet
def get_pr_id(self):
try:
pr_id = f"{self.workspace_slug}/{self.repo_slug}/{self.pr_num}"
return pr_id
except Exception as e:
if get_settings().config.verbosity_level >= 2:
get_logger().error(f"Failed to get pr id, error: {e}")
return ""
def publish_file_comments(self, file_comments: list) -> bool:
    """Placeholder: nothing is published and None is returned despite the `bool` annotation."""
    return None

View File

@ -0,0 +1,561 @@
import difflib
import json
import re
from typing import Optional, Tuple
from urllib.parse import urlparse
import requests
from atlassian.bitbucket import Cloud
from starlette_context import context
from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from ..algo.file_filter import filter_ignored
from ..algo.language_handler import is_valid_file
from ..algo.utils import find_line_number_of_relevant_line_in_file
from ..config_loader import get_settings
from ..log import get_logger
from .git_provider import MAX_FILES_ALLOWED_FULL, GitProvider
def _gef_filename(diff):
if diff.new.path:
return diff.new.path
return diff.old.path
class BitbucketProvider(GitProvider):
def __init__(
        self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
):
    """
    Create an authenticated Bitbucket Cloud client and, when `pr_url` is given, bind to that PR.

    The bearer token is read from the request context; when the context cannot be read, the
    `BITBUCKET.BEARER_TOKEN` setting is used instead.

    The two API URL attributes are only derived when a pull request was loaded. Without a
    `pr_url` they are None, because `self.pr` is still None at that point and dereferencing it
    would fail.
    """
    s = requests.Session()
    try:
        bearer = context.get("bitbucket_bearer_token", None)
        s.headers["Authorization"] = f"Bearer {bearer}"
    except Exception:
        # not in a context environment: fall back to the configured token
        s.headers[
            "Authorization"
        ] = f'Bearer {get_settings().get("BITBUCKET.BEARER_TOKEN", None)}'
    s.headers["Content-Type"] = "application/json"
    self.headers = s.headers
    self.bitbucket_client = Cloud(session=s)
    self.max_comment_length = 31000
    self.workspace_slug = None
    self.repo_slug = None
    self.repo = None
    self.pr_num = None
    self.pr = None
    self.pr_url = pr_url
    self.temp_comments = []
    self.incremental = incremental
    self.diff_files = None
    self.git_files = None
    self.bitbucket_comment_api_url = None
    self.bitbucket_pull_request_api_url = None
    if pr_url:
        self.set_pr(pr_url)
        pr_links = self.pr._BitbucketBase__data["links"]
        self.bitbucket_comment_api_url = pr_links["comments"]["href"]
        self.bitbucket_pull_request_api_url = pr_links['self']['href']
def get_repo_settings(self):
    """
    Download `.pr_agent.toml` from the PR's destination branch.

    Returns:
        The response text encoded as UTF-8 bytes, or an empty string when the file does not
        exist (HTTP 404) or the request fails.
    """
    try:
        settings_url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
                        f"{self.pr.destination_branch}/.pr_agent.toml")
        reply = requests.request("GET", settings_url, headers=self.headers)
        if reply.status_code == 404:  # not found
            return ""
        return reply.text.encode('utf-8')
    except Exception:
        return ""
def publish_code_suggestions(self, code_suggestions: list) -> bool:
    """
    Publishes code suggestions as comments on the PR.

    When a suggestion carries 'original_suggestion', the ```suggestion``` block in its body is
    replaced by a ```diff``` block built from the existing and improved code. Suggestions with a
    missing/-1 start line, or whose end line precedes the start line, are skipped with a warning.

    Returns:
        True when the inline comments were handed over without an exception, False otherwise.
    """
    post_parameters_list = []
    for suggestion in code_suggestions:
        body = suggestion["body"]
        original_suggestion = suggestion.get('original_suggestion', None)  # needed for diff code
        if original_suggestion:
            try:
                existing_code = original_suggestion['existing_code'].rstrip() + "\n"
                improved_code = original_suggestion['improved_code'].rstrip() + "\n"
                diff = difflib.unified_diff(existing_code.split('\n'),
                                            improved_code.split('\n'), n=999)
                patch_orig = "\n".join(diff)
                patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
                diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
                # replace ```suggestion ... ``` with diff_code, using regex:
                body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL)
            except Exception as e:
                get_logger().exception(f"Bitbucket failed to get diff code for publishing, error: {e}")
                continue

        relevant_file = suggestion["relevant_file"]
        relevant_lines_start = suggestion["relevant_lines_start"]
        relevant_lines_end = suggestion["relevant_lines_end"]

        # These are validation failures, not caught exceptions, so they are logged with
        # warning(): exception() is meant to be called from inside an exception handler.
        if not relevant_lines_start or relevant_lines_start == -1:
            get_logger().warning(
                f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}"
            )
            continue

        if relevant_lines_end < relevant_lines_start:
            get_logger().warning(
                f"Failed to publish code suggestion, "
                f"relevant_lines_end is {relevant_lines_end} and "
                f"relevant_lines_start is {relevant_lines_start}"
            )
            continue

        if relevant_lines_end > relevant_lines_start:
            post_parameters = {
                "body": body,
                "path": relevant_file,
                "line": relevant_lines_end,
                "start_line": relevant_lines_start,
                "start_side": "RIGHT",
            }
        else:  # API is different for single line comments
            post_parameters = {
                "body": body,
                "path": relevant_file,
                "line": relevant_lines_start,
                "side": "RIGHT",
            }
        post_parameters_list.append(post_parameters)

    try:
        self.publish_inline_comments(post_parameters_list)
        return True
    except Exception as e:
        get_logger().error(f"Bitbucket failed to publish code suggestion, error: {e}")
        return False
def publish_file_comments(self, file_comments: list) -> bool:
    """Placeholder: nothing is published and None is returned despite the `bool` annotation."""
    return None
def is_supported(self, capability: str) -> bool:
    """Return False for the capabilities this provider opts out of, True for everything else."""
    unsupported = ['get_issue_comments', 'publish_inline_comments', 'get_labels', 'gfm_markdown',
                   'publish_file_comments']
    return capability not in unsupported
def set_pr(self, pr_url: str):
self.workspace_slug, self.repo_slug, self.pr_num = self._parse_pr_url(pr_url)
self.pr = self._get_pr()
def get_files(self):
    """
    Return the file names of the PR's diffstat, reusing the list cached in the request context.

    When the context cannot be used, the list is cached on the instance only.
    """
    try:
        cached_names = context.get("git_files", None)
        if cached_names:
            return cached_names
        self.git_files = list(map(_gef_filename, self.pr.diffstat()))
        context["git_files"] = self.git_files
        return self.git_files
    except Exception:
        # no usable context: fall back to the per-instance cache
        if not self.git_files:
            self.git_files = list(map(_gef_filename, self.pr.diffstat()))
        return self.git_files
def get_diff_files(self) -> list[FilePatchInfo]:
    """
    Build the per-file patch information for the pull request and cache it on the instance.

    The PR diff is split into one chunk per file and matched by position to the diffstat
    entries that survive the ignore filter. Full file contents are fetched only for a limited
    number of files.

    Returns:
        A list of FilePatchInfo, or an empty list when the diff cannot be split into one chunk
        per kept file.
    """
    if self.diff_files:
        return self.diff_files

    all_stats = list(self.pr.diffstat())
    kept_stats = filter_ignored(all_stats, 'bitbucket')
    if kept_stats != all_stats:
        try:
            names_original = [d.new.path for d in all_stats]
            names_kept = [d.new.path for d in kept_stats]
            names_filtered = list(set(names_original) - set(names_kept))
            get_logger().info(f"Filtered out [ignore] files for PR", extra={
                'original_files': names_original,
                'names_kept': names_kept,
                'names_filtered': names_filtered
            })
        except Exception as e:
            pass

    # get the pr patches
    try:
        pr_patches = self.pr.diff()
    except Exception as e:
        # Try different encodings if UTF-8 fails
        get_logger().warning(f"Failed to decode PR patch with utf-8, error: {e}")
        encodings_to_try = ['iso-8859-1', 'latin-1', 'ascii', 'utf-16']
        pr_patches = None
        for encoding in encodings_to_try:
            try:
                pr_patches = self.pr.diff(encoding=encoding)
                get_logger().info(f"Successfully decoded PR patch with encoding {encoding}")
                break
            except UnicodeDecodeError:
                continue

        if pr_patches is None:
            raise ValueError(f"Failed to decode PR patch with encodings {encodings_to_try}")

    diff_split = ["diff --git" + chunk for chunk in pr_patches.split("diff --git") if chunk.strip()]
    # drop the chunks whose diffstat entry was filtered out (chunks and entries align by index)
    if len(diff_split) > len(kept_stats) and len(all_stats) == len(diff_split):
        diff_split = [chunk for chunk, stat in zip(diff_split, all_stats) if stat in kept_stats]
    if len(diff_split) != len(kept_stats):
        get_logger().error(f"Error - failed to split the diff into {len(kept_stats)} parts")
        return []

    # bitbucket diff has a header for each file, we need to remove it:
    # "diff --git filename
    # new file mode 100644 (optional)
    # index caa56f0..61528d7 100644
    # --- a/pr_agent/cli_pip.py
    # +++ b/pr_agent/cli_pip.py
    # @@ -... @@"
    for i in range(len(diff_split)):
        chunk_lines = diff_split[i].splitlines()
        has_header = len(chunk_lines) >= 6 and (
            (chunk_lines[2].startswith("---") and
             chunk_lines[3].startswith("+++") and
             chunk_lines[4].startswith("@@")) or
            (chunk_lines[3].startswith("---") and  # new or deleted file
             chunk_lines[4].startswith("+++") and
             chunk_lines[5].startswith("@@"))
        )
        if has_header:
            diff_split[i] = "\n".join(chunk_lines[4:])
            continue

        # no recognizable header: the chunk is discarded, only the logging differs
        stat_data = kept_stats[i].data
        if stat_data.get('lines_added', 0) == 0 and stat_data.get('lines_removed', 0) == 0:
            pass
        elif len(chunk_lines) <= 3:
            get_logger().info(f"Disregarding empty diff for file {_gef_filename(kept_stats[i])}")
        else:
            get_logger().warning(f"Bitbucket failed to get diff for file {_gef_filename(kept_stats[i])}")
        diff_split[i] = ""

    status_to_edit_type = {
        'added': EDIT_TYPE.ADDED,
        'removed': EDIT_TYPE.DELETED,
        'modified': EDIT_TYPE.MODIFIED,
        'renamed': EDIT_TYPE.RENAMED,
    }

    invalid_files_names = []
    diff_files = []
    counter_valid = 0
    # get full files
    for index, diff in enumerate(kept_stats):
        file_path = _gef_filename(diff)
        if not is_valid_file(file_path):
            invalid_files_names.append(file_path)
            continue

        original_file_content_str = ""
        new_file_content_str = ""
        try:
            counter_valid += 1
            full_content_quota = MAX_FILES_ALLOWED_FULL // 2  # factor 2 because bitbucket has limited API calls
            if get_settings().get("bitbucket_app.avoid_full_files", False):
                pass
            elif counter_valid < full_content_quota:
                old_links = diff.old.get_data("links")
                if old_links:
                    original_file_content_str = self._get_pr_file_content(old_links['self']['href'])
                new_links = diff.new.get_data("links")
                if new_links:
                    new_file_content_str = self._get_pr_file_content(new_links['self']['href'])
            elif counter_valid == full_content_quota:
                get_logger().info(
                    f"Bitbucket too many files in PR, will avoid loading full content for rest of files")
        except Exception as e:
            get_logger().exception(f"Error - bitbucket failed to get file content, error: {e}")
            original_file_content_str = ""
            new_file_content_str = ""

        file_patch_canonic_structure = FilePatchInfo(
            original_file_content_str,
            new_file_content_str,
            diff_split[index],
            file_path,
        )

        status = diff.data['status']
        if status in status_to_edit_type:
            file_patch_canonic_structure.edit_type = status_to_edit_type[status]
        diff_files.append(file_patch_canonic_structure)

    if invalid_files_names:
        get_logger().info(f"Disregarding files with invalid extensions:\n{invalid_files_names}")

    self.diff_files = diff_files
    return diff_files
def get_latest_commit_url(self):
return self.pr.data['source']['commit']['links']['html']['href']
def get_comment_url(self, comment):
    """Return the HTML link of `comment`, read from its data."""
    comment_links = comment.data['links']
    return comment_links['html']['href']
def publish_persistent_comment(self, pr_comment: str,
                               initial_header: str,
                               update_header: bool = True,
                               name='review',
                               final_update_message=True):
    """
    Update the first existing comment containing `initial_header`, or publish a new comment.

    When a matching comment is found it is overwritten with `pr_comment` (optionally with a
    header noting the latest commit), and a short pointer comment is posted if
    `final_update_message` is set. If no comment matches, or the update fails, `pr_comment` is
    published as a regular comment.
    """
    try:
        for comment in self.pr.comments():
            if initial_header not in comment.raw:
                continue

            latest_commit_url = self.get_latest_commit_url()
            comment_url = self.get_comment_url(comment)
            pr_comment_updated = pr_comment
            if update_header:
                updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n"
                pr_comment_updated = pr_comment.replace(initial_header, updated_header)

            get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message")
            payload = {"content": {"raw": pr_comment_updated}}
            comment._update_data(comment.put(None, data=payload))
            if final_update_message:
                self.publish_comment(
                    f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}")
            return
    except Exception as e:
        get_logger().exception(f"Failed to update persistent review, error: {e}")
    self.publish_comment(pr_comment)
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
if is_temporary and not get_settings().config.publish_output_progress:
get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
return None
pr_comment = self.limit_output_characters(pr_comment, self.max_comment_length)
comment = self.pr.comment(pr_comment)
if is_temporary:
self.temp_comments.append(comment["id"])
return comment
def edit_comment(self, comment, body: str):
try:
body = self.limit_output_characters(body, self.max_comment_length)
comment.update(body)
except Exception as e:
get_logger().exception(f"Failed to update comment, error: {e}")
def remove_initial_comment(self):
try:
for comment in self.temp_comments:
self.remove_comment(comment)
except Exception as e:
get_logger().exception(f"Failed to remove temp comments, error: {e}")
def remove_comment(self, comment):
try:
self.pr.delete(f"comments/{comment}")
except Exception as e:
get_logger().exception(f"Failed to remove comment, error: {e}")
# Builds the description of an inline comment; publishing happens elsewhere.
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
                          absolute_position: int = None):
    """
    Locate `relevant_line_in_file` in the PR diff and describe an inline comment for it.

    Returns:
        A dict with 'body' (limited to `self.max_comment_length`), 'path' and 'position' (the
        absolute line number), or an empty dict when the line could not be located.
    """
    body = self.limit_output_characters(body, self.max_comment_length)
    position, absolute_position = find_line_number_of_relevant_line_in_file(self.get_diff_files(),
                                                                            relevant_file.strip('`'),
                                                                            relevant_line_in_file,
                                                                            absolute_position)
    path = relevant_file.strip()
    if position == -1:
        if get_settings().config.verbosity_level >= 2:
            get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
        return {}
    return dict(body=body, path=path, position=absolute_position)
def publish_inline_comment(self, comment: str, from_line: int, file: str, original_suggestion=None):
    """
    POST an inline comment on line `from_line` of `file` to the PR's comment API URL.

    Returns:
        The `requests` response object of the POST.
    """
    limited_comment = self.limit_output_characters(comment, self.max_comment_length)
    request_body = {
        "content": {
            "raw": limited_comment,
        },
        "inline": {
            "to": from_line,
            "path": file
        },
    }
    return requests.request(
        "POST", self.bitbucket_comment_api_url, data=json.dumps(request_body), headers=self.headers
    )
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
if relevant_line_start == -1:
link = f"{self.pr_url}/#L{relevant_file}"
else:
link = f"{self.pr_url}/#L{relevant_file}T{relevant_line_start}"
return link
def generate_link_to_relevant_line_number(self, suggestion) -> str:
    """Return a link to the line named by ``suggestion``, or ``""``.

    ``suggestion`` must provide ``'relevant_file'`` and ``'relevant_line'``.
    An empty string is returned when the line text is blank, the line cannot
    be located in the diff, ``self.pr_url`` is falsy, or any error occurs.
    """
    try:
        file_name = suggestion['relevant_file'].strip('`').strip("'").rstrip()
        line_text = suggestion['relevant_line'].rstrip()
        if line_text:
            _, absolute_position = find_line_number_of_relevant_line_in_file(
                self.get_diff_files(), file_name, line_text)
            if absolute_position != -1 and self.pr_url:
                return f"{self.pr_url}/#L{file_name}T{absolute_position}"
    except Exception as e:
        if get_settings().config.verbosity_level >= 2:
            get_logger().info(f"Failed adding line link, error: {e}")
    return ""
def publish_inline_comments(self, comments: list[dict]):
    """Publish each comment dict through ``self.publish_inline_comment``.

    The target line is taken from the first key present among ``'position'``,
    ``'start_line'`` and ``'line'``. A comment with none of them is logged as
    an error and skipped.
    """
    # note that bitbucket does not seem to support range - only a comment on a single line - https://community.developer.atlassian.com/t/api-post-endpoint-for-inline-pull-request-comments/60452
    # so a multi-line comment ('start_line') is anchored on its first line.
    line_keys = ('position', 'start_line', 'line')
    for comment in comments:
        line_key = next((key for key in line_keys if key in comment), None)
        if line_key is None:
            get_logger().error(f"Could not publish inline comment {comment}")
            continue
        self.publish_inline_comment(comment['body'], comment[line_key], comment['path'])
def get_title(self):
    """Return the ``title`` attribute of ``self.pr``."""
    pull_request = self.pr
    return pull_request.title
def get_languages(self):
    """Return ``{language: 0}`` for the repository's ``"language"`` field."""
    repo_language = self._get_repo().get_data("language")
    return {repo_language: 0}
def get_pr_branch(self):
    """Return the ``source_branch`` attribute of ``self.pr``."""
    pull_request = self.pr
    return pull_request.source_branch
def get_pr_owner_id(self) -> str | None:
return self.workspace_slug
def get_pr_description_full(self):
    """Return the ``description`` attribute of ``self.pr`` unchanged."""
    pull_request = self.pr
    return pull_request.description
def get_user_id(self):
    """Return a constant id; this provider always reports 0."""
    user_id = 0
    return user_id
def get_issue_comments(self):
    """Always raise ``NotImplementedError``."""
    message = "Bitbucket provider does not support issue comments yet"
    raise NotImplementedError(message)
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
    """No-op: both arguments are ignored and ``True`` is always returned."""
    acknowledged = True
    return acknowledged
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
    """No-op: both arguments are ignored and ``True`` is always returned."""
    removed = True
    return removed
@staticmethod
def _parse_pr_url(pr_url: str) -> Tuple[str, int]:
parsed_url = urlparse(pr_url)
if "bitbucket.org" not in parsed_url.netloc:
raise ValueError("The provided URL is not a valid Bitbucket URL")
path_parts = parsed_url.path.strip("/").split("/")
if len(path_parts) < 4 or path_parts[2] != "pull-requests":
raise ValueError(
"The provided URL does not appear to be a Bitbucket PR URL"
)
workspace_slug = path_parts[0]
repo_slug = path_parts[1]
try:
pr_number = int(path_parts[3])
except ValueError as e:
raise ValueError("Unable to convert PR number to integer") from e
return workspace_slug, repo_slug, pr_number
def _get_repo(self):
    """Return the repository object, fetching and caching it on first use.

    The lookup goes through ``self.bitbucket_client.workspaces`` using
    ``self.workspace_slug`` and ``self.repo_slug``; the result is stored in
    ``self.repo``.
    """
    if self.repo is not None:
        return self.repo
    workspace = self.bitbucket_client.workspaces.get(self.workspace_slug)
    self.repo = workspace.repositories.get(self.repo_slug)
    return self.repo
def _get_pr(self):
    """Fetch pull request ``self.pr_num`` from the repository object."""
    repository = self._get_repo()
    return repository.pullrequests.get(self.pr_num)
def get_pr_file_content(self, file_path: str, branch: str) -> str:
    """Download ``file_path`` at ``branch`` from the Bitbucket ``src`` API.

    When ``branch`` equals the PR's source or destination branch it is
    replaced by the matching commit hash from ``self.pr.data``.

    Returns:
        The response text, or ``""`` on a 404 or on any exception.
    """
    try:
        ref = branch
        if ref == self.pr.source_branch:
            ref = self.pr.data["source"]["commit"]["hash"]
        elif ref == self.pr.destination_branch:
            ref = self.pr.data["destination"]["commit"]["hash"]
        url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
               f"{ref}/{file_path}")
        response = requests.request("GET", url, headers=self.headers)
        return "" if response.status_code == 404 else response.text
    except Exception:
        return ""
def create_or_update_pr_file(self, file_path: str, branch: str, contents="", message="") -> None:
    """POST ``contents`` as ``file_path`` on ``branch`` via the ``src`` API.

    When ``message`` is empty a default is generated: ``Update <path>`` if
    ``contents`` is non-empty, otherwise ``Create <path>``. Only the
    ``Authorization`` header (if present) is forwarded. A request failure is
    logged rather than propagated.
    """
    url = f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
    if not message:
        message = f"Update {file_path}" if contents else f"Create {file_path}"
    form_fields = {"message": message, "branch": branch}
    uploads = {file_path: contents}
    if 'Authorization' in self.headers:
        auth_headers = {'Authorization': self.headers['Authorization']}
    else:
        auth_headers = {}
    try:
        requests.request("POST", url, headers=auth_headers, data=form_fields, files=uploads)
    except Exception:
        get_logger().exception(f"Failed to create empty file {file_path} in branch {branch}")
def _get_pr_file_content(self, remote_link: str):
    """GET ``remote_link`` with ``self.headers`` and return the body text.

    Returns ``""`` on a 404 response or on any exception.
    """
    try:
        response = requests.request("GET", remote_link, headers=self.headers)
        return "" if response.status_code == 404 else response.text
    except Exception:
        return ""
def get_commit_messages(self):
    """Return an empty string; commit messages are not implemented yet."""
    commit_messages = ""
    return commit_messages
# bitbucket does not support labels
def publish_description(self, pr_title: str, description: str):
    """Update the pull request's title and description with a PUT request.

    The request goes to ``self.bitbucket_pull_request_api_url``. A non-200
    status is logged but not raised.

    Returns:
        The response object returned by ``requests.request``, whatever its
        status code.
    """
    payload = json.dumps({
        "description": description,
        "title": pr_title
    })

    response = requests.request("PUT", self.bitbucket_pull_request_api_url, headers=self.headers, data=payload)
    try:
        if response.status_code != 200:
            get_logger().info(f"Failed to update description, error code: {response.status_code}")
    except Exception:
        # Reporting the failure is best-effort and must not hide the response,
        # but KeyboardInterrupt/SystemExit are no longer swallowed.
        pass
    return response
# bitbucket does not support labels
def publish_labels(self, pr_types: list):
    """Do nothing: this provider does not publish labels."""
    return None
# bitbucket does not support labels
def get_pr_labels(self, update=False):
    """Do nothing and return ``None``: this provider has no labels."""
    return None

View File

@ -0,0 +1,483 @@
import difflib
import re
from packaging.version import parse as parse_version
from typing import Optional, Tuple
from urllib.parse import quote_plus, urlparse
from atlassian.bitbucket import Bitbucket
from requests.exceptions import HTTPError
from ..algo.git_patch_processing import decode_if_bytes
from ..algo.language_handler import is_valid_file
from ..algo.types import EDIT_TYPE, FilePatchInfo
from ..algo.utils import (find_line_number_of_relevant_line_in_file,
load_large_diff)
from ..config_loader import get_settings
from ..log import get_logger
from .git_provider import GitProvider
class BitbucketServerProvider(GitProvider):
def __init__(
        self, pr_url: Optional[str] = None, incremental: Optional[bool] = False,
        bitbucket_client: Optional[Bitbucket] = None,
):
    """Set up the provider and, when ``pr_url`` is given, load that PR.

    Args:
        pr_url: URL of the pull request; also used to derive the server URL.
        incremental: stored on the instance as ``self.incremental``.
        bitbucket_client: client to use; when falsy, a ``Bitbucket`` client
            is created from the derived server URL and the
            ``BITBUCKET_SERVER.BEARER_TOKEN`` setting.
    """
    # PR identity; the slug/number/pr fields are filled in by set_pr().
    self.bitbucket_server_url = None
    self.workspace_slug = None
    self.repo_slug = None
    self.repo = None
    self.pr_num = None
    self.pr = None
    self.pr_url = pr_url
    self.temp_comments = []
    self.incremental = incremental
    self.diff_files = None
    self.bitbucket_pull_request_api_url = pr_url

    self.bitbucket_server_url = self._parse_bitbucket_server(url=pr_url)
    if bitbucket_client:
        self.bitbucket_client = bitbucket_client
    else:
        self.bitbucket_client = Bitbucket(
            url=self.bitbucket_server_url,
            token=get_settings().get("BITBUCKET_SERVER.BEARER_TOKEN", None),
        )

    # The server version selects the diff strategy later; None means "unknown".
    try:
        properties = self.bitbucket_client.get("rest/api/1.0/application-properties")
        self.bitbucket_api_version = parse_version(properties.get('version'))
    except Exception:
        self.bitbucket_api_version = None

    if pr_url:
        self.set_pr(pr_url)
def get_repo_settings(self):
    """Return the content of ``.pr_agent.toml`` from the PR branch.

    Returns ``""`` when the file cannot be read. An HTTP 404 is treated as
    "no settings file" and is not logged; every other failure is logged.
    """
    try:
        return self.bitbucket_client.get_content_of_file(
            self.workspace_slug, self.repo_slug, ".pr_agent.toml", self.get_pr_branch())
    except Exception as e:
        is_missing_file = isinstance(e, HTTPError) and e.response.status_code == 404
        if not is_missing_file:
            get_logger().error(f"Failed to load .pr_agent.toml file, error: {e}")
        return ""
def get_pr_id(self):
    """Return the pull request number stored in ``self.pr_num``."""
    pr_id = self.pr_num
    return pr_id
def publish_code_suggestions(self, code_suggestions: list) -> bool:
    """
    Publishes code suggestions as inline comments on the PR.

    Each suggestion dict must provide "body", "relevant_file",
    "relevant_lines_start" and "relevant_lines_end". An optional
    "original_suggestion" (with "existing_code" and "improved_code") is
    rendered as a diff block that replaces the suggestion block in the body.
    Suggestions whose diff cannot be built, or whose line range is unusable,
    are logged and skipped.

    Returns:
        True if the collected comments were published without error,
        False otherwise.
    """
    post_parameters_list = []
    for suggestion in code_suggestions:
        body = suggestion["body"]
        original_suggestion = suggestion.get('original_suggestion', None)  # needed for diff code
        if original_suggestion:
            try:
                existing_code = original_suggestion['existing_code'].rstrip() + "\n"
                improved_code = original_suggestion['improved_code'].rstrip() + "\n"
                diff = difflib.unified_diff(existing_code.split('\n'),
                                            improved_code.split('\n'), n=999)
                patch_orig = "\n".join(diff)
                # drop the unified-diff header lines, keep only the hunk body
                patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
                diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
                # replace ```suggestion ... ``` with diff_code, using regex.
                # A callable replacement inserts diff_code verbatim; passing the
                # string itself would make re.sub interpret backslashes in the
                # suggested code as escapes/group references.
                body = re.sub(r'```suggestion.*?```', lambda _match: diff_code, body, flags=re.DOTALL)
            except Exception as e:
                get_logger().exception(f"Bitbucket failed to get diff code for publishing, error: {e}")
                continue

        relevant_file = suggestion["relevant_file"]
        relevant_lines_start = suggestion["relevant_lines_start"]
        relevant_lines_end = suggestion["relevant_lines_end"]

        if not relevant_lines_start or relevant_lines_start == -1:
            get_logger().warning(
                f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}"
            )
            continue

        if relevant_lines_end < relevant_lines_start:
            get_logger().warning(
                f"Failed to publish code suggestion, "
                f"relevant_lines_end is {relevant_lines_end} and "
                f"relevant_lines_start is {relevant_lines_start}"
            )
            continue

        if relevant_lines_end > relevant_lines_start:
            # Bitbucket does not support multi-line suggestions so use a code block instead - https://jira.atlassian.com/browse/BSERV-4553
            body = body.replace("```suggestion", "```")
            post_parameters = {
                "body": body,
                "path": relevant_file,
                "line": relevant_lines_end,
                "start_line": relevant_lines_start,
                "start_side": "RIGHT",
            }
        else:  # API is different for single line comments
            post_parameters = {
                "body": body,
                "path": relevant_file,
                "line": relevant_lines_start,
                "side": "RIGHT",
            }
        post_parameters_list.append(post_parameters)

    try:
        self.publish_inline_comments(post_parameters_list)
        return True
    except Exception as e:
        if get_settings().config.verbosity_level >= 2:
            get_logger().error(f"Failed to publish code suggestion, error: {e}")
        return False
def publish_file_comments(self, file_comments: list) -> bool:
    """Do nothing and return ``None``; file-level comments are not published."""
    return None
def is_supported(self, capability: str) -> bool:
    """Return False for the capabilities this provider lacks, True otherwise."""
    unsupported = ('get_issue_comments', 'get_labels', 'gfm_markdown', 'publish_file_comments')
    return capability not in unsupported
def set_pr(self, pr_url: str):
    """Parse ``pr_url`` into slug/repo/number and load the PR into ``self.pr``."""
    parsed = self._parse_pr_url(pr_url)
    self.workspace_slug, self.repo_slug, self.pr_num = parsed
    self.pr = self._get_pr()
def get_file(self, path: str, commit_id: str):
    """Return the content of ``path`` at ``commit_id``.

    An ``HTTPError`` from the client is logged at debug level and ``""`` is
    returned instead.
    """
    try:
        return self.bitbucket_client.get_content_of_file(
            self.workspace_slug, self.repo_slug, path, commit_id)
    except HTTPError:
        get_logger().debug(f"File {path} not found at commit id: {commit_id}")
        return ""
def get_files(self):
    """Return the ``path.toString`` value of every change in the PR."""
    changed_paths = []
    for change in self.bitbucket_client.get_pull_requests_changes(
            self.workspace_slug, self.repo_slug, self.pr_num):
        changed_paths.append(change["path"]['toString'])
    return changed_paths
#gets the best common ancestor: https://git-scm.com/docs/git-merge-base
@staticmethod
def get_best_common_ancestor(source_commits_list, destination_commits_list, guaranteed_common_ancestor) -> str:
destination_commit_hashes = {commit['id'] for commit in destination_commits_list} | {guaranteed_common_ancestor}
for commit in source_commits_list:
for parent_commit in commit['parents']:
if parent_commit['id'] in destination_commit_hashes:
return parent_commit['id']
return guaranteed_common_ancestor
def get_diff_files(self) -> list[FilePatchInfo]:
if self.diff_files:
return self.diff_files
head_sha = self.pr.fromRef['latestCommit']
# if Bitbucket api version is >= 8.16 then use the merge-base api for 2-way diff calculation
if self.bitbucket_api_version is not None and self.bitbucket_api_version >= parse_version("8.16"):
try:
base_sha = self.bitbucket_client.get(self._get_merge_base())['id']
except Exception as e:
get_logger().error(f"Failed to get the best common ancestor for PR: {self.pr_url}, \nerror: {e}")
raise e
else:
source_commits_list = list(self.bitbucket_client.get_pull_requests_commits(
self.workspace_slug,
self.repo_slug,
self.pr_num
))
# if Bitbucket api version is None or < 7.0 then do a simple diff with a guaranteed common ancestor
base_sha = source_commits_list[-1]['parents'][0]['id']
# if Bitbucket api version is 7.0-8.15 then use 2-way diff functionality for the base_sha
if self.bitbucket_api_version is not None and self.bitbucket_api_version >= parse_version("7.0"):
try:
destination_commits = list(
self.bitbucket_client.get_commits(self.workspace_slug, self.repo_slug, base_sha,
self.pr.toRef['latestCommit']))
base_sha = self.get_best_common_ancestor(source_commits_list, destination_commits, base_sha)
except Exception as e:
get_logger().error(
f"Failed to get the commit list for calculating best common ancestor for PR: {self.pr_url}, \nerror: {e}")
raise e
diff_files = []
original_file_content_str = ""
new_file_content_str = ""
changes = self.bitbucket_client.get_pull_requests_changes(self.workspace_slug, self.repo_slug, self.pr_num)
for change in changes:
file_path = change['path']['toString']
if not is_valid_file(file_path.split("/")[-1]):
get_logger().info(f"Skipping a non-code file: {file_path}")
continue
match change['type']:
case 'ADD':
edit_type = EDIT_TYPE.ADDED
new_file_content_str = self.get_file(file_path, head_sha)
new_file_content_str = decode_if_bytes(new_file_content_str)
original_file_content_str = ""
case 'DELETE':
edit_type = EDIT_TYPE.DELETED
new_file_content_str = ""
original_file_content_str = self.get_file(file_path, base_sha)
original_file_content_str = decode_if_bytes(original_file_content_str)
case 'RENAME':
edit_type = EDIT_TYPE.RENAMED
case _:
edit_type = EDIT_TYPE.MODIFIED
original_file_content_str = self.get_file(file_path, base_sha)
original_file_content_str = decode_if_bytes(original_file_content_str)
new_file_content_str = self.get_file(file_path, head_sha)
new_file_content_str = decode_if_bytes(new_file_content_str)
patch = load_large_diff(file_path, new_file_content_str, original_file_content_str, show_warning=False)
diff_files.append(
FilePatchInfo(
original_file_content_str,
new_file_content_str,
patch,
file_path,
edit_type=edit_type,
)
)
self.diff_files = diff_files
return diff_files
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
    """Add ``pr_comment`` to the PR; temporary comments are not published."""
    if is_temporary:
        return
    self.bitbucket_client.add_pull_request_comment(
        self.workspace_slug, self.repo_slug, self.pr_num, pr_comment)
def remove_initial_comment(self):
    """Pass every entry of ``self.temp_comments`` to ``self.remove_comment``.

    Only a ``ValueError`` is caught (and logged); other exceptions propagate.
    """
    try:
        for temp_comment in self.temp_comments:
            self.remove_comment(temp_comment)
    except ValueError as e:
        get_logger().exception(f"Failed to remove temp comments, error: {e}")
def remove_comment(self, comment):
    """Do nothing: comment removal is not implemented for this provider."""
    return None
# function to create_inline_comment
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
                          absolute_position: int = None):
    """Build the dict describing an inline comment, without publishing it.

    The target line is resolved against ``self.get_diff_files()``.

    Returns:
        ``dict(body=..., path=..., position=...)`` when the line was located,
        otherwise an empty dict.
    """
    position, absolute_position = find_line_number_of_relevant_line_in_file(
        self.get_diff_files(),
        relevant_file.strip('`'),
        relevant_line_in_file,
        absolute_position
    )
    if position != -1:
        return dict(body=body, path=relevant_file.strip(), position=absolute_position)

    # Line could not be located: nothing to anchor the comment to.
    if get_settings().config.verbosity_level >= 2:
        get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
    return {}
def publish_inline_comment(self, comment: str, from_line: int, file: str, original_suggestion=None):
    """POST one inline comment anchored on an added line of ``file``.

    ``original_suggestion`` is accepted but not used here.

    Raises:
        Exception: any error from the client is logged and re-raised.
    """
    anchor = {
        "diffType": "EFFECTIVE",
        "path": file,
        "lineType": "ADDED",
        "line": from_line,
        "fileType": "TO"
    }
    payload = {"text": comment, "severity": "NORMAL", "anchor": anchor}

    try:
        self.bitbucket_client.post(self._get_pr_comments_path(), data=payload)
    except Exception as e:
        get_logger().error(f"Failed to publish inline comment to '{file}' at line {from_line}, error: {e}")
        raise e
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
    """Return a diff link to a file, or to a line of it, under ``self.pr_url``.

    The file name is URL-quoted. A ``relevant_line_start`` of -1 yields a
    file-level link; otherwise ``?t=<line>`` is appended.
    ``relevant_line_end`` is not used.
    """
    file_link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}"
    return file_link if relevant_line_start == -1 else f"{file_link}?t={relevant_line_start}"
def generate_link_to_relevant_line_number(self, suggestion) -> str:
    """Return a diff link to the line named by ``suggestion``, or ``""``.

    ``suggestion`` must provide ``'relevant_file'`` and ``'relevant_line'``.
    An empty string is returned when the line text is blank, the line cannot
    be located in the diff, neither ``self.pr`` nor ``self.pr_url`` is set,
    or any error occurs.
    """
    # Bound before the try block so the except handler can always format it,
    # even when reading suggestion['relevant_file'] is what failed.
    relevant_file = ""
    try:
        relevant_file = suggestion['relevant_file'].strip('`').strip("'").rstrip()
        relevant_line_str = suggestion['relevant_line'].rstrip()
        if not relevant_line_str:
            return ""

        diff_files = self.get_diff_files()
        position, absolute_position = find_line_number_of_relevant_line_in_file \
            (diff_files, relevant_file, relevant_line_str)

        if absolute_position != -1:
            if self.pr:
                link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}?t={absolute_position}"
                return link
            else:
                if get_settings().config.verbosity_level >= 2:
                    get_logger().info(f"Failed adding line link to '{relevant_file}' since PR not set")
        else:
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"Failed adding line link to '{relevant_file}' since position not found")

        # Reached only when self.pr is falsy; a PR URL alone is still enough to link.
        if absolute_position != -1 and self.pr_url:
            link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}?t={absolute_position}"
            return link
    except Exception as e:
        if get_settings().config.verbosity_level >= 2:
            get_logger().info(f"Failed adding line link to '{relevant_file}', error: {e}")

    return ""
def publish_inline_comments(self, comments: list[dict]):
    """Publish each comment dict through ``self.publish_inline_comment``.

    The target line is taken from the first key present among ``'position'``,
    ``'start_line'`` and ``'line'``. A comment with none of them is logged as
    an error and skipped.
    """
    # note that bitbucket does not seem to support range - only a comment on a single line - https://community.developer.atlassian.com/t/api-post-endpoint-for-inline-pull-request-comments/60452
    # so a multi-line comment ('start_line') is anchored on its first line.
    line_keys = ('position', 'start_line', 'line')
    for comment in comments:
        line_key = next((key for key in line_keys if key in comment), None)
        if line_key is None:
            get_logger().error(f"Could not publish inline comment: {comment}")
            continue
        self.publish_inline_comment(comment['body'], comment[line_key], comment['path'])
def get_title(self):
    """Return the ``title`` attribute of ``self.pr``."""
    pull_request = self.pr
    return pull_request.title
def get_languages(self):
    """Return a fixed language map; no repository lookup is performed."""
    placeholder_languages = {"yaml": 0}  # devops LOL
    return placeholder_languages
def get_pr_branch(self):
    """Return the ``displayId`` of the PR's ``fromRef``."""
    source_ref = self.pr.fromRef
    return source_ref['displayId']
def get_pr_owner_id(self) -> str | None:
return self.workspace_slug
def get_pr_description_full(self):
    """Return the PR's ``description`` attribute, or None when it has none."""
    return getattr(self.pr, "description", None)
def get_user_id(self):
    """Return a constant id; this provider always reports 0."""
    user_id = 0
    return user_id
def get_issue_comments(self):
    """Always raise ``NotImplementedError``."""
    message = "Bitbucket provider does not support issue comments yet"
    raise NotImplementedError(message)
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
    """No-op: both arguments are ignored and ``True`` is always returned."""
    acknowledged = True
    return acknowledged
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
    """No-op: both arguments are ignored and ``True`` is always returned."""
    removed = True
    return removed
@staticmethod
def _parse_bitbucket_server(url: str) -> str:
# pr url format: f"{bitbucket_server}/projects/{project_name}/repos/{repository_name}/pull-requests/{pr_id}"
parsed_url = urlparse(url)
server_path = parsed_url.path.split("/projects/")
if len(server_path) > 1:
server_path = server_path[0].strip("/")
return f"{parsed_url.scheme}://{parsed_url.netloc}/{server_path}".strip("/")
return f"{parsed_url.scheme}://{parsed_url.netloc}"
@staticmethod
def _parse_pr_url(pr_url: str) -> Tuple[str, str, int]:
# pr url format: f"{bitbucket_server}/projects/{project_name}/repos/{repository_name}/pull-requests/{pr_id}"
parsed_url = urlparse(pr_url)
path_parts = parsed_url.path.strip("/").split("/")
try:
projects_index = path_parts.index("projects")
except ValueError:
projects_index = -1
try:
users_index = path_parts.index("users")
except ValueError:
users_index = -1
if projects_index == -1 and users_index == -1:
raise ValueError(f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL")
if projects_index != -1:
path_parts = path_parts[projects_index:]
else:
path_parts = path_parts[users_index:]
if len(path_parts) < 6 or path_parts[2] != "repos" or path_parts[4] != "pull-requests":
raise ValueError(
f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL"
)
workspace_slug = path_parts[1]
if users_index != -1:
workspace_slug = f"~{workspace_slug}"
repo_slug = path_parts[3]
try:
pr_number = int(path_parts[5])
except ValueError as e:
raise ValueError(f"Unable to convert PR number '{path_parts[5]}' to integer") from e
return workspace_slug, repo_slug, pr_number
def _get_repo(self):
    """Return the repository object, fetching and caching it on first use."""
    if self.repo is not None:
        return self.repo
    self.repo = self.bitbucket_client.get_repo(self.workspace_slug, self.repo_slug)
    return self.repo
def _get_pr(self):
    """Fetch the pull request and expose its fields as attributes.

    The mapping returned by the client becomes the namespace of a new class
    named ``new_dict``, so each key is readable as an attribute.

    Raises:
        Exception: any error is logged and re-raised.
    """
    try:
        pr_fields = self.bitbucket_client.get_pull_request(
            self.workspace_slug, self.repo_slug, pull_request_id=self.pr_num)
        return type('new_dict', (object,), pr_fields)
    except Exception as e:
        get_logger().error(f"Failed to get pull request, error: {e}")
        raise e
def _get_pr_file_content(self, remote_link: str):
    """Return an empty string; ``remote_link`` is not fetched."""
    content = ""
    return content
def get_commit_messages(self):
    """Return an empty string; commit messages are not collected."""
    commit_messages = ""
    return commit_messages
# bitbucket does not support labels
def publish_description(self, pr_title: str, description: str):
    """Update the pull request's title and description.

    The PR's current ``version`` and ``reviewers`` are sent along with the
    new values.

    Raises:
        Exception: any error from the client is logged and re-raised.
    """
    current_pr = self.pr
    payload = {
        "version": current_pr.version,
        "description": description,
        "title": pr_title,
        "reviewers": current_pr.reviewers  # needs to be sent otherwise gets wiped
    }
    try:
        self.bitbucket_client.update_pull_request(
            self.workspace_slug, self.repo_slug, str(self.pr_num), payload)
    except Exception as e:
        get_logger().error(f"Failed to update pull request, error: {e}")
        raise e
# bitbucket does not support labels
def publish_labels(self, pr_types: list):
    """Do nothing: this provider does not publish labels."""
    return None
# bitbucket does not support labels
def get_pr_labels(self, update=False):
    """Do nothing and return ``None``: this provider has no labels."""
    return None
def _get_pr_comments_path(self):
    """Return the REST path of this pull request's ``comments`` resource."""
    project, repo, pr_id = self.workspace_slug, self.repo_slug, self.pr_num
    return f"rest/api/latest/projects/{project}/repos/{repo}/pull-requests/{pr_id}/comments"
def _get_merge_base(self):
    """Return the REST path of this pull request's ``merge-base`` resource."""
    project, repo, pr_id = self.workspace_slug, self.repo_slug, self.pr_num
    return f"rest/api/latest/projects/{project}/repos/{repo}/pull-requests/{pr_id}/merge-base"

View File

@ -0,0 +1,277 @@
import boto3
import botocore
class CodeCommitDifferencesResponse:
    """
    One entry of the list returned by our get_differences() function.

    The "beforeBlob", "afterBlob" and "changeType" fields of the JSON entry
    are copied to member variables; every missing field becomes "".
    """

    def __init__(self, json: dict):
        before, after = json.get("beforeBlob", {}), json.get("afterBlob", {})

        self.before_blob_id, self.before_blob_path = before.get("blobId", ""), before.get("path", "")
        self.after_blob_id, self.after_blob_path = after.get("blobId", ""), after.get("path", "")
        self.change_type = json.get("changeType", "")
class CodeCommitPullRequestResponse:
    """
    The response object returned from our get_pr() function.

    Copies "title" and "description" from the JSON response and wraps every
    entry of "pullRequestTargets" in a CodeCommitPullRequestTarget.
    """

    def __init__(self, json: dict):
        self.title = json.get("title", "")
        self.description = json.get("description", "")
        self.targets = [
            CodeCommitPullRequestResponse.CodeCommitPullRequestTarget(target_json)
            for target_json in json.get("pullRequestTargets", [])
        ]

    class CodeCommitPullRequestTarget:
        """
        Nested class holding the source/destination commit and reference of
        one pull request target; every missing field becomes "".
        """

        def __init__(self, json: dict):
            self.source_commit = json.get("sourceCommit", "")
            self.source_branch = json.get("sourceReference", "")
            self.destination_commit = json.get("destinationCommit", "")
            self.destination_branch = json.get("destinationReference", "")
class CodeCommitClient:
"""
CodeCommitClient is a wrapper around the AWS boto3 SDK for the CodeCommit client
"""
def __init__(self):
    """Start without a boto3 client; one is created on first use."""
    # None means "not connected yet" (see _connect_boto_client)
    self.boto_client = None
def is_supported(self, capability: str) -> bool:
    """Return False for "gfm_markdown" and True for anything else."""
    unsupported = ("gfm_markdown",)
    return capability not in unsupported
def _connect_boto_client(self):
    """Create the boto3 "codecommit" client and store it on the instance.

    Raises:
        ValueError: if the client cannot be created.
    """
    try:
        client = boto3.client("codecommit")
        self.boto_client = client
    except Exception as e:
        raise ValueError(f"Failed to connect to AWS CodeCommit: {e}") from e
def get_differences(self, repo_name: str, destination_commit: str, source_commit: str):
    """
    Get the differences between two commits in CodeCommit.

    Args:
    - repo_name: Name of the repository
    - destination_commit: Commit hash you want to merge into (the "before" hash) (usually on the main or master branch)
    - source_commit: Commit hash of the code you are adding (the "after" branch)

    Returns:
    - List of CodeCommitDifferencesResponse objects

    Raises:
    - ValueError: if the repository does not exist or the differences cannot be retrieved

    Boto3 Documentation:
    - aws codecommit get-differences
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/get_differences.html
    """
    if self.boto_client is None:
        self._connect_boto_client()

    # The differences response from AWS is paginated, so we need to iterate through the pages to get all the differences.
    differences = []
    try:
        paginator = self.boto_client.get_paginator("get_differences")
        for page in paginator.paginate(
            repositoryName=repo_name,
            beforeCommitSpecifier=destination_commit,
            afterCommitSpecifier=source_commit,
        ):
            differences.extend(page.get("differences", []))
    except botocore.exceptions.ClientError as e:
        if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
            raise ValueError(f"CodeCommit cannot retrieve differences: Repository does not exist: {repo_name}") from e
        raise ValueError(f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}") from e
    except Exception as e:
        raise ValueError(f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}") from e

    return [CodeCommitDifferencesResponse(difference) for difference in differences]
def get_file(self, repo_name: str, file_path: str, sha_hash: str, optional: bool = False):
    """
    Retrieve a file from CodeCommit.

    Args:
    - repo_name: Name of the repository
    - file_path: Path to the file you are retrieving; an empty path returns "" without calling AWS
    - sha_hash: Commit hash of the file you are retrieving
    - optional: When True, a file that does not exist returns "" instead of raising

    Returns:
    - File contents

    Raises:
    - ValueError: if the repository or file cannot be retrieved, or the response has no "fileContent"

    Boto3 Documentation:
    - aws codecommit get_file
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/get_file.html
    """
    if not file_path:
        return ""

    if self.boto_client is None:
        self._connect_boto_client()

    try:
        response = self.boto_client.get_file(repositoryName=repo_name, commitSpecifier=sha_hash, filePath=file_path)
    except botocore.exceptions.ClientError as e:
        error_code = e.response["Error"]["Code"]
        if error_code == 'RepositoryDoesNotExistException':
            raise ValueError(f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}") from e
        # if the file does not exist, but is flagged as optional, then return an empty string
        if optional and error_code == 'FileDoesNotExistException':
            return ""
        raise ValueError(f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'") from e
    except Exception as e:
        raise ValueError(f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'") from e

    if "fileContent" in response:
        return response.get("fileContent", "")
    raise ValueError(f"File content is empty for file: {file_path}")
def get_pr(self, repo_name: str, pr_number: int):
    """
    Get information about a CodeCommit PR.

    Args:
    - repo_name: Name of the repository (only used in error messages)
    - pr_number: The PR number you are requesting

    Returns:
    - CodeCommitPullRequestResponse object

    Raises:
    - ValueError: if the PR or repository does not exist, the request fails, or the response has no "pullRequest"

    Boto3 Documentation:
    - aws codecommit get_pull_request
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/get_pull_request.html
    """
    if self.boto_client is None:
        self._connect_boto_client()

    try:
        response = self.boto_client.get_pull_request(pullRequestId=str(pr_number))
    except botocore.exceptions.ClientError as e:
        if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
            raise ValueError(f"CodeCommit cannot retrieve PR: PR number does not exist: {pr_number}") from e
        if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
            raise ValueError(f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}") from e
        raise ValueError(f"CodeCommit cannot retrieve PR: {pr_number}: boto client error") from e
    except Exception as e:
        raise ValueError(f"CodeCommit cannot retrieve PR: {pr_number}") from e

    if "pullRequest" not in response:
        # f-string so the message carries the actual PR number
        raise ValueError(f"CodeCommit PR number not found: {pr_number}")

    return CodeCommitPullRequestResponse(response.get("pullRequest", {}))
def publish_description(self, pr_number: int, pr_title: str, pr_body: str):
    """
    Set the title and description on a pull request

    Args:
    - pr_number: the AWS CodeCommit pull request number
    - pr_title: title of the pull request
    - pr_body: body of the pull request

    Returns:
    - None

    Raises:
    - ValueError: if either update fails; known boto3 error codes get a specific message

    Boto3 Documentation:
    - aws codecommit update_pull_request_title
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/update_pull_request_title.html
    - aws codecommit update_pull_request_description
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/update_pull_request_description.html
    """
    if self.boto_client is None:
        self._connect_boto_client()

    try:
        self.boto_client.update_pull_request_title(pullRequestId=str(pr_number), title=pr_title)
        self.boto_client.update_pull_request_description(pullRequestId=str(pr_number), description=pr_body)
    except botocore.exceptions.ClientError as e:
        # Map the known error codes to their messages; anything else is generic.
        known_errors = {
            'PullRequestDoesNotExistException': f"PR number does not exist: {pr_number}",
            'InvalidTitleException': f"Invalid title for PR number: {pr_number}",
            'InvalidDescriptionException': f"Invalid description for PR number: {pr_number}",
            'PullRequestAlreadyClosedException': f"PR is already closed: PR number: {pr_number}",
        }
        error_code = e.response["Error"]["Code"]
        message = known_errors.get(error_code, "Boto3 client error calling publish_description")
        raise ValueError(message) from e
    except Exception as e:
        raise ValueError("Error calling publish_description") from e
def publish_comment(self, repo_name: str, pr_number: int, destination_commit: str, source_commit: str, comment: str, annotation_file: str = None, annotation_line: int = None):
    """
    Publish a comment to a pull request

    Args:
    - repo_name: name of the repository
    - pr_number: number of the pull request
    - destination_commit: The commit hash you want to merge into (the "before" hash) (usually on the main or master branch)
    - source_commit: The commit hash of the code you are adding (the "after" branch)
    - comment: The comment you want to publish
    - annotation_file: The file you want to annotate (optional)
    - annotation_line: The line number you want to annotate (optional)

    Comment annotations for CodeCommit are different than GitHub.
    CodeCommit only designates the starting line number for the comment.
    It does not support the ending line number to highlight a range of lines.

    Returns:
    - None

    Raises:
    - ValueError: if the repository or PR does not exist, or the request fails

    Boto3 Documentation:
    - aws codecommit post_comment_for_pull_request
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/post_comment_for_pull_request.html
    """
    if self.boto_client is None:
        self._connect_boto_client()

    try:
        request = dict(
            pullRequestId=str(pr_number),
            repositoryName=repo_name,
            beforeCommitId=destination_commit,
            afterCommitId=source_commit,
            content=comment,
        )
        # If the comment has code annotations,
        # then set the file path and line number in the location dictionary
        if annotation_file and annotation_line:
            request["location"] = {
                "filePath": annotation_file,
                "filePosition": annotation_line,
                "relativeFileVersion": "AFTER",
            }
        self.boto_client.post_comment_for_pull_request(**request)
    except botocore.exceptions.ClientError as e:
        error_code = e.response["Error"]["Code"]
        if error_code == 'RepositoryDoesNotExistException':
            raise ValueError(f"Repository does not exist: {repo_name}") from e
        if error_code == 'PullRequestDoesNotExistException':
            raise ValueError(f"PR number does not exist: {pr_number}") from e
        raise ValueError("Boto3 client error calling post_comment_for_pull_request") from e
    except Exception as e:
        raise ValueError("Error calling post_comment_for_pull_request") from e

View File

@ -0,0 +1,497 @@
import os
import re
from collections import Counter
from typing import List, Optional, Tuple
from urllib.parse import urlparse
from utils.pr_agent.algo.language_handler import is_valid_file
from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from utils.pr_agent.git_providers.codecommit_client import CodeCommitClient
from ..algo.utils import load_large_diff
from ..config_loader import get_settings
from ..log import get_logger
from .git_provider import GitProvider
class PullRequestCCMimic:
    """
    Stand-in for the PyGithub PullRequest class, used by the CodeCommitProvider.

    Only the title and diff files are set on construction; the description
    and the commit/branch fields start as None.
    """

    def __init__(self, title: str, diff_files: List[FilePatchInfo]):
        self.title, self.diff_files = title, diff_files
        self.description = None
        # source side: the branch containing your new code changes
        self.source_commit = self.source_branch = None
        # destination side: the branch you are going to merge into
        self.destination_commit = self.destination_branch = None
class CodeCommitFile:
    """
    One changed file of a CodeCommit pull request.

    The "a" side is the file before the change and the "b" side is the file after it;
    each side has a path and a blob id. `filename` is the "b" path when that is
    non-empty and the "a" path otherwise.
    """

    def __init__(
        self,
        a_path: str,
        a_blob_id: str,
        b_path: str,
        b_blob_id: str,
        edit_type: EDIT_TYPE,
    ):
        # "before" side
        self.a_path, self.a_blob_id = a_path, a_blob_id
        # "after" side
        self.b_path, self.b_blob_id = b_path, b_blob_id
        self.edit_type: EDIT_TYPE = edit_type
        # Prefer the "after" path; fall back to the "before" path when it is empty/None
        self.filename = b_path or a_path
class CodeCommitProvider(GitProvider):
    """
    This class implements the GitProvider interface for AWS CodeCommit repositories.

    All calls to CodeCommit go through a CodeCommitClient instance, and the pull request
    is represented by a PullRequestCCMimic object stored in `self.pr`.
    """

    def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
        """
        Create the provider and, when a PR URL is given, load that pull request.
        Args:
        - pr_url: optional CodeCommit pull request URL
        - incremental: not used by this provider
        """
        self.codecommit_client = CodeCommitClient()
        self.aws_client = None
        self.repo_name = None
        self.pr_num = None
        self.pr = None
        self.diff_files = None  # cache filled by get_diff_files()
        self.git_files = None  # cache filled by get_files()
        self.pr_url = pr_url
        if pr_url:
            self.set_pr(pr_url)

    def provider_name(self):
        return "CodeCommit"

    def is_supported(self, capability: str) -> bool:
        """
        Returns:
        - bool: False for the capabilities this provider does not implement, True otherwise.
        """
        unsupported = (
            "get_issue_comments",
            "create_inline_comment",
            "publish_inline_comments",
            "get_labels",
            "gfm_markdown",
        )
        return capability not in unsupported

    def set_pr(self, pr_url: str):
        """
        Parse the PR URL into repository name and PR number, then load the PR from CodeCommit.
        """
        self.repo_name, self.pr_num = self._parse_pr_url(pr_url)
        self.pr = self._get_pr()

    def get_files(self) -> list[CodeCommitFile]:
        """
        Return the files that differ between the PR's destination and source commits.
        The result is fetched from CodeCommit once and cached.
        """
        # bring files from CodeCommit only once
        # (an explicit None check, so that an empty result is cached as well)
        if self.git_files is not None:
            return self.git_files

        differences = self.codecommit_client.get_differences(
            self.repo_name, self.pr.destination_commit, self.pr.source_commit)

        # Build the list completely before caching it, so a failure cannot leave a partial cache
        self.git_files = [
            CodeCommitFile(
                item.before_blob_path,
                item.before_blob_id,
                item.after_blob_path,
                item.after_blob_id,
                CodeCommitProvider._get_edit_type(item.change_type),
            )
            for item in differences
        ]
        return self.git_files

    def get_diff_files(self) -> list[FilePatchInfo]:
        """
        Retrieves the list of files that have been modified, added, deleted, or renamed in a pull request in CodeCommit,
        along with their content and patch information.
        Returns:
        diff_files (List[FilePatchInfo]): List of FilePatchInfo objects representing the modified, added, deleted,
        or renamed files in the merge request.
        """
        # bring files from CodeCommit only once
        if self.diff_files is not None:
            return self.diff_files

        # Collected locally and cached only when the loop completes:
        # if fetching a file raises midway, the next call starts over instead of returning a partial list.
        diff_files = []
        for diff_item in self.get_files():
            patch_filename = ""

            # "before" content: empty when the file has no blob on the destination side
            if diff_item.a_blob_id is not None:
                patch_filename = diff_item.a_path
                original_file_content_str = self.codecommit_client.get_file(
                    self.repo_name, diff_item.a_path, self.pr.destination_commit)
                if isinstance(original_file_content_str, (bytes, bytearray)):
                    original_file_content_str = original_file_content_str.decode("utf-8")
            else:
                original_file_content_str = ""

            # "after" content: empty when the file has no blob on the source side
            if diff_item.b_blob_id is not None:
                patch_filename = diff_item.b_path
                new_file_content_str = self.codecommit_client.get_file(
                    self.repo_name, diff_item.b_path, self.pr.source_commit)
                if isinstance(new_file_content_str, (bytes, bytearray)):
                    new_file_content_str = new_file_content_str.decode("utf-8")
            else:
                new_file_content_str = ""

            patch = load_large_diff(patch_filename, new_file_content_str, original_file_content_str)

            # Store the diffs as a list of FilePatchInfo objects
            info = FilePatchInfo(
                original_file_content_str,
                new_file_content_str,
                patch,
                diff_item.b_path,
                edit_type=diff_item.edit_type,
                old_filename=None
                if diff_item.a_path == diff_item.b_path
                else diff_item.a_path,
            )

            # Only add valid files to the diff list
            # "bad extensions" are set in the language_extensions.toml file
            # a "valid file" is one that is not in the "bad extensions" list
            if is_valid_file(info.filename):
                diff_files.append(info)

        self.diff_files = diff_files
        return self.diff_files

    def publish_description(self, pr_title: str, pr_body: str):
        """
        Publish the PR title and body to CodeCommit.
        Raises:
        - ValueError: when the CodeCommit client fails for any reason
        """
        try:
            self.codecommit_client.publish_description(
                pr_number=self.pr_num,
                pr_title=pr_title,
                pr_body=CodeCommitProvider._add_additional_newlines(pr_body),
            )
        except Exception as e:
            raise ValueError(f"CodeCommit Cannot publish description for PR: {self.pr_num}") from e

    def publish_comment(self, pr_comment: str, is_temporary: bool = False):
        """
        Publish a comment on the PR. Temporary comments are only logged, not published.
        Raises:
        - ValueError: when the CodeCommit client fails for any reason
        """
        if is_temporary:
            get_logger().info(pr_comment)
            return

        pr_comment = CodeCommitProvider._remove_markdown_html(pr_comment)
        pr_comment = CodeCommitProvider._add_additional_newlines(pr_comment)

        try:
            self.codecommit_client.publish_comment(
                repo_name=self.repo_name,
                pr_number=self.pr_num,
                destination_commit=self.pr.destination_commit,
                source_commit=self.pr.source_commit,
                comment=pr_comment,
            )
        except Exception as e:
            raise ValueError(f"CodeCommit Cannot publish comment for PR: {self.pr_num}") from e

    def publish_code_suggestions(self, code_suggestions: list) -> bool:
        """
        Publish each code suggestion as a comment annotated with a file and a line.
        Suggestions missing one of the required keys are skipped with a warning.
        Returns:
        - bool: always True (see the comment at the end of the method)
        Raises:
        - ValueError: when publishing one of the suggestions fails
        """
        required_keys = ("body", "relevant_file", "relevant_lines_start")

        # The counter is the 1-based position in the incoming list, so that skipped
        # suggestions do not shift the numbers shown in the log messages.
        for counter, suggestion in enumerate(code_suggestions, start=1):
            # Verify that each suggestion has the required keys
            if not all(key in suggestion for key in required_keys):
                get_logger().warning(f"Skipping code suggestion #{counter}: Each suggestion must have 'body', 'relevant_file', 'relevant_lines_start' keys")
                continue

            # Publish the code suggestion to CodeCommit
            try:
                get_logger().debug(f"Code Suggestion #{counter} in file: {suggestion['relevant_file']}: {suggestion['relevant_lines_start']}")
                self.codecommit_client.publish_comment(
                    repo_name=self.repo_name,
                    pr_number=self.pr_num,
                    destination_commit=self.pr.destination_commit,
                    source_commit=self.pr.source_commit,
                    comment=suggestion["body"],
                    annotation_file=suggestion["relevant_file"],
                    annotation_line=suggestion["relevant_lines_start"],
                )
            except Exception as e:
                raise ValueError(f"CodeCommit Cannot publish code suggestions for PR: {self.pr_num}") from e

        # The calling function passes in a list of code suggestions, and this function publishes each suggestion one at a time.
        # If we were to return False here, the calling function will attempt to publish the same list of code suggestions again, one at a time.
        # Since this function publishes the suggestions one at a time anyway, we always return True here to avoid the retry.
        return True

    def publish_labels(self, labels):
        return [""]  # not implemented yet

    def get_pr_labels(self, update=False):
        return [""]  # not implemented yet

    def remove_initial_comment(self):
        return ""  # not implemented yet

    def remove_comment(self, comment):
        return ""  # not implemented yet

    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/post_comment_for_compared_commit.html
        raise NotImplementedError("CodeCommit provider does not support publishing inline comments yet")

    def publish_inline_comments(self, comments: list[dict]):
        raise NotImplementedError("CodeCommit provider does not support publishing inline comments yet")

    def get_title(self):
        return self.pr.title

    def get_pr_id(self):
        """
        Returns the PR ID in the format: "repo_name/pr_number".
        Note: This is an internal identifier for PR-Agent,
        and is not the same as the CodeCommit PR identifier.
        An empty string is returned if the identifier cannot be built.
        """
        try:
            return f"{self.repo_name}/{self.pr_num}"
        except Exception:
            # Narrowed from a bare "except:" so KeyboardInterrupt/SystemExit are not swallowed
            return ""

    def get_languages(self):
        """
        Returns a dictionary of languages, containing the percentage of each language used in the PR.
        Files whose extension is not in the language_extension_map settings are reported under the "" key.
        Returns:
        - dict: A dictionary where each key is a language name and the corresponding value is the percentage of that language in the PR.
        """
        commit_files = self.get_files()
        filenames = [item.filename for item in commit_files]
        file_extensions = CodeCommitProvider._get_file_extensions(filenames)

        # Calculate the percentage of each file extension in the PR
        percentages = CodeCommitProvider._get_language_percentages(file_extensions)

        # The global language_extension_map is a dictionary of languages,
        # where each dictionary item is a BoxList of extensions.
        # We want a dictionary of extensions,
        # where each dictionary item is a language name.
        # We build that extension->language dictionary here in main_extensions_flat.
        main_extensions_flat = {}
        language_extension_map_org = get_settings().language_extension_map_org
        language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
        for language, language_extensions in language_extension_map.items():
            for ext in language_extensions:
                main_extensions_flat[ext] = language

        # Map the file extension/languages to percentages.
        # Several extensions can belong to the same language (and several unknown extensions
        # all map to ""), so the percentages are accumulated rather than overwritten.
        languages = {}
        for ext, pct in percentages.items():
            language = main_extensions_flat.get(ext, "")
            languages[language] = languages.get(language, 0) + pct

        return languages

    def get_pr_branch(self):
        return self.pr.source_branch

    def get_pr_description_full(self) -> str:
        return self.pr.description

    def get_user_id(self):
        return -1  # not implemented yet

    def get_issue_comments(self):
        raise NotImplementedError("CodeCommit provider does not support issue comments yet")

    def get_repo_settings(self):
        """
        Return the content of the optional ".pr_agent.toml" file at the PR's source commit,
        as returned by the CodeCommit client.
        """
        # a local ".pr_agent.toml" settings file is optional
        settings_filename = ".pr_agent.toml"
        return self.codecommit_client.get_file(self.repo_name, settings_filename, self.pr.source_commit, optional=True)

    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
        get_logger().info("CodeCommit provider does not support eyes reaction yet")
        return True

    def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
        get_logger().info("CodeCommit provider does not support removing reactions yet")
        return True

    @staticmethod
    def _parse_pr_url(pr_url: str) -> Tuple[str, int]:
        """
        Parse the CodeCommit PR URL and return the repository name and PR number.
        Args:
        - pr_url: the full AWS CodeCommit pull request URL
        Returns:
        - Tuple[str, int]: A tuple containing the repository name and PR number.
        Raises:
        - ValueError: when the hostname or the path does not look like a CodeCommit PR URL,
          or when the PR number is not an integer
        """
        # Example PR URL:
        # https://us-east-1.console.aws.amazon.com/codesuite/codecommit/repositories/__MY_REPO__/pull-requests/123456
        parsed_url = urlparse(pr_url)

        if not CodeCommitProvider._is_valid_codecommit_hostname(parsed_url.netloc):
            raise ValueError(f"The provided URL is not a valid CodeCommit URL: {pr_url}")

        path_parts = parsed_url.path.strip("/").split("/")

        if (
            len(path_parts) < 6
            or path_parts[0] != "codesuite"
            or path_parts[1] != "codecommit"
            or path_parts[2] != "repositories"
            or path_parts[4] != "pull-requests"
        ):
            raise ValueError(f"The provided URL does not appear to be a CodeCommit PR URL: {pr_url}")

        repo_name = path_parts[3]

        try:
            pr_number = int(path_parts[5])
        except ValueError as e:
            raise ValueError(f"Unable to convert PR number to integer: '{path_parts[5]}'") from e

        return repo_name, pr_number

    @staticmethod
    def _is_valid_codecommit_hostname(hostname: str) -> bool:
        """
        Check if the provided hostname is a valid AWS CodeCommit hostname.
        This is not an exhaustive check of AWS region names,
        but instead uses a regex to check for matching AWS region patterns.
        Args:
        - hostname: the hostname to check
        Returns:
        - bool: True if the hostname is valid, False otherwise.
        """
        return re.match(r"^[a-z]{2}-(gov-)?[a-z]+-\d\.console\.aws\.amazon\.com$", hostname) is not None

    def _get_pr(self):
        """
        Load the PR from CodeCommit and wrap it in a PullRequestCCMimic object.
        Raises:
        - ValueError: when the PR has no targets
        """
        response = self.codecommit_client.get_pr(self.repo_name, self.pr_num)

        if len(response.targets) == 0:
            raise ValueError(f"No files found in CodeCommit PR: {self.pr_num}")

        # TODO: implement support for multiple targets in one CodeCommit PR
        #       for now, we are only using the first target in the PR
        if len(response.targets) > 1:
            get_logger().warning(
                "Multiple targets in one PR is not supported for CodeCommit yet. Continuing, using the first target only..."
            )

        # Return our object that mimics PullRequest class from the PyGithub library
        # (This strategy was copied from the LocalGitProvider)
        target = response.targets[0]
        mimic = PullRequestCCMimic(response.title, self.diff_files)
        mimic.description = response.description
        mimic.source_commit = target.source_commit
        mimic.source_branch = target.source_branch
        mimic.destination_commit = target.destination_commit
        mimic.destination_branch = target.destination_branch

        return mimic

    def get_commit_messages(self):
        return ""  # not implemented yet

    @staticmethod
    def _add_additional_newlines(body: str) -> str:
        """
        Replace single newlines in a PR body with double newlines.
        CodeCommit Markdown does not seem to render as well as GitHub Markdown,
        so we add additional newlines to the PR body to make it more readable in CodeCommit.
        Args:
        - body: the PR body
        Returns:
        - str: the PR body with the double newlines added
        """
        # Only a newline that is neither preceded nor followed by another newline is doubled
        return re.sub(r'(?<!\n)\n(?!\n)', '\n\n', body)

    @staticmethod
    def _remove_markdown_html(comment: str) -> str:
        """
        Remove the HTML tags from a PR comment.
        CodeCommit Markdown does not seem to render as well as GitHub Markdown,
        so we remove the HTML tags from the PR comment to make it more readable in CodeCommit.
        Only the <details> and <summary> tags (opening and closing) are removed.
        Args:
        - comment: the PR comment
        Returns:
        - str: the PR comment with the HTML tags removed
        """
        for tag in ("<details>", "</details>", "<summary>", "</summary>"):
            comment = comment.replace(tag, "")
        return comment

    @staticmethod
    def _get_edit_type(codecommit_change_type: str):
        """
        Convert the CodeCommit change type string to the EDIT_TYPE enum.
        The CodeCommit change type string is returned from the get_differences SDK method.
        Args:
        - codecommit_change_type: the CodeCommit change type string
        Returns:
        - An EDIT_TYPE enum representing the modified, added, deleted, or renamed file in the PR diff,
          or None when the change type is not one of A, D, M, R (case-insensitive).
        """
        edit_types = {
            "A": EDIT_TYPE.ADDED,
            "D": EDIT_TYPE.DELETED,
            "M": EDIT_TYPE.MODIFIED,
            "R": EDIT_TYPE.RENAMED,
        }
        return edit_types.get(codecommit_change_type.upper())

    @staticmethod
    def _get_file_extensions(filenames):
        """
        Return a list of file extensions from a list of filenames.
        The returned extensions will include the dot "." prefix,
        to accommodate for the dots in the existing language_extension_map settings.
        Filenames with no extension will return an empty string for the extension.
        Args:
        - filenames: a list of filenames
        Returns:
        - list: A list of lower-cased file extensions, including the dot "." prefix.
        """
        # os.path.splitext returns "" as the extension when there is none
        return [os.path.splitext(filename)[1].lower() for filename in filenames]

    @staticmethod
    def _get_language_percentages(extensions):
        """
        Return a dictionary containing the file extension (as the key),
        and the percentage of the given list that extension accounts for (as the value).
        Percentages are rounded to whole numbers.
        Args:
        - extensions: a list of file extensions
        Returns:
        - dict: A dictionary where each key is a file extension and the corresponding value is its percentage in the PR.
          An empty dictionary is returned for an empty list.
        """
        total_files = len(extensions)
        if total_files == 0:
            return {}

        # Count the occurrences of each extension and convert the counts to percentages
        return {
            ext: round(count / total_files * 100)
            for ext, count in Counter(extensions).items()
        }

View File

@ -0,0 +1,399 @@
import json
import os
import pathlib
import shutil
import subprocess
import uuid
from collections import Counter, namedtuple
from pathlib import Path
from tempfile import NamedTemporaryFile, mkdtemp
import requests
import urllib3.util
from git import Repo
from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers.git_provider import GitProvider
from utils.pr_agent.git_providers.local_git_provider import PullRequestMimic
from utils.pr_agent.log import get_logger
def _call(*command, **kwargs) -> str:
    """
    Run an external command and return its decoded standard output.

    Args:
    - command: the program and its arguments, passed to subprocess.run as one sequence
    - kwargs: extra keyword arguments forwarded to subprocess.run (for example cwd)
    Returns:
    - str: the command's standard output, decoded with the default codec of bytes.decode()
    Raises:
    - subprocess.CalledProcessError: when the command exits with a non-zero status
    """
    res = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,  # captured so it is available on CalledProcessError
        check=True,
        **kwargs,
    )
    return res.stdout.decode()
def clone(url, directory):
    """
    Shallow-clone (depth 1) a git repository into a local directory and log git's stdout.

    Args:
    - url: the URL of the repository to clone
    - directory: the directory to clone into
    """
    # An f-string is used, like the other log calls in this project: "%s" placeholders with
    # positional arguments are not interpolated by a loguru-style logger, so the URL and
    # directory would be missing from the log line.
    get_logger().info(f"Cloning {url} to {directory}")
    stdout = _call('git', 'clone', "--depth", "1", url, directory)
    get_logger().info(stdout)
def fetch(url, refspec, cwd):
    """
    Fetch a refspec from a remote with a history depth of 2 and log git's stdout.

    Args:
    - url: the URL of the remote repository
    - refspec: the refspec to fetch
    - cwd: the directory of the local repository in which git is run
    """
    # An f-string is used, like the other log calls in this project: "%s" placeholders with
    # positional arguments are not interpolated by a loguru-style logger.
    get_logger().info(f"Fetching {url} {refspec}")
    stdout = _call(
        'git', 'fetch', '--depth', '2', url, refspec,
        cwd=cwd
    )
    get_logger().info(stdout)
def checkout(cwd):
    """
    Check out FETCH_HEAD in the repository located at cwd and log git's stdout.
    """
    logger = get_logger()
    logger.info("Checking out")
    git_output = _call('git', 'checkout', "FETCH_HEAD", cwd=cwd)
    logger.info(git_output)
def show(*args, cwd=None):
    """
    Run "git show" with the given arguments in cwd and return its stdout.
    """
    get_logger().info("Show")
    command = ('git', 'show', *args)
    return _call(*command, cwd=cwd)
def diff(*args, cwd=None):
    """
    Run "git diff" with the given arguments in cwd.

    Returns:
    - the diff text, or None (after logging a warning) when the output is empty
    """
    get_logger().info("Diff")
    output = _call('git', 'diff', *args, cwd=cwd)
    if output:
        return output
    get_logger().warning("No changes found")
    return None
def reset_local_changes(cwd):
    """
    Discard local modifications in the repository at cwd by running "git checkout --force".
    """
    get_logger().info("Reset local changes")
    force_checkout = ('git', 'checkout', "--force")
    _call(*force_checkout, cwd=cwd)
def add_comment(url: urllib3.util.Url, refspec, message):
    """
    Post a review message on a Gerrit change by running "gerrit review" over SSH.

    Args:
    - url: parsed Gerrit URL; its port, auth and host are used for the SSH connection
    - refspec: refspec of the change; its last two "/"-separated components identify the target
    - message: the review message
    Returns:
    - the stdout of the ssh command
    """
    # Assuming the usual Gerrit layout refs/changes/NN/<change>/<patchset>,
    # the last two components are the change number and the patch set number.
    *_, change_number, patchset_number = refspec.rsplit("/")
    target = f"{change_number},{patchset_number}"

    # Wrap the message in single quotes for the remote shell;
    # embedded single quotes are closed, quoted and reopened.
    escaped = message.replace("'", "'\"'\"'")
    quoted_message = "'" + escaped + "'"

    ssh_destination = f"{url.auth}@{url.host}"
    return _call(
        "ssh",
        "-p", str(url.port),
        ssh_destination,
        "gerrit", "review",
        "--message", quoted_message,
        target,
    )
def list_comments(url: urllib3.util.Url, refspec):
    """
    Return the comments of the current patch set of a Gerrit change,
    obtained by running "gerrit query" over SSH.

    Args:
    - url: parsed Gerrit URL; its port, auth and host are used for the SSH connection
    - refspec: refspec of the change; its second-to-last "/"-separated component is used as the query
    Returns:
    - the "comments" entry of "currentPatchSet" from the first JSON line printed by the query
    """
    # Assuming the usual Gerrit layout refs/changes/NN/<change>/<patchset>,
    # the second-to-last component is the change number.
    *_, change_number, _unused = refspec.rsplit("/")
    ssh_destination = f"{url.auth}@{url.host}"
    query_output = _call(
        "ssh",
        "-p", str(url.port),
        ssh_destination,
        "gerrit", "query",
        "--comments",
        "--current-patch-set", change_number,
        "--format", "JSON",
    )
    # Only the first output line is parsed; any following lines are ignored
    first_line, *_ = query_output.splitlines()
    change = json.loads(first_line)
    return change["currentPatchSet"]["comments"]
def prepare_repo(url: urllib3.util.Url, project, refspec):
    """
    Clone a Gerrit project into a fresh temporary directory and check out the given refspec.

    Args:
    - url: parsed Gerrit URL (scheme, auth, host and port are used)
    - project: the project name, appended to the URL as the repository path
    - refspec: the refspec to fetch and check out
    Returns:
    - pathlib.Path: the temporary directory containing the checked-out repository.
      The directory is not removed by this function on success.
    Raises:
    - whatever clone, fetch or checkout raise; the temporary directory is removed first
    """
    repo_url = f"{url.scheme}://{url.auth}@{url.host}:{url.port}/{project}"

    directory = pathlib.Path(mkdtemp())
    try:
        clone(repo_url, directory)
        fetch(repo_url, refspec, cwd=directory)
        checkout(cwd=directory)
    except Exception:
        # Do not leave a half-prepared temporary directory behind when git fails
        shutil.rmtree(directory, ignore_errors=True)
        raise
    return directory
def adopt_to_gerrit_message(message):
    """
    Convert a Markdown message into the plain-text form used for Gerrit comments.

    Every line is stripped of "*" characters, of <details>/<summary> tags and of
    surrounding whitespace, and double backticks are reduced to single ones.
    A line starting with "#" becomes a heading: a blank line, then the text without
    any "#" and with exactly one trailing ":". A line starting with "-" loses that prefix.

    Args:
    - message: the Markdown message
    Returns:
    - str: the converted message, stripped of leading and trailing whitespace
    """
    # (needle, replacement) pairs, applied to every line in this order
    substitutions = (
        ("*", ""),
        ("``", "`"),
        ("<details>", ""),
        ("</details>", ""),
        ("<summary>", ""),
        ("</summary>", ""),
    )

    converted = []
    for raw_line in message.splitlines():
        # remove markdown formatting
        text = raw_line
        for needle, replacement in substitutions:
            text = text.replace(needle, replacement)
        text = text.strip()

        if text.startswith('#'):
            heading = text.replace('#', '').removesuffix(":").strip()
            text = "\n" + heading + ":"
        elif text.startswith('-'):
            text = text.removeprefix('-').strip()
        converted.append(text)

    return "\n".join(converted).strip()
def add_suggestion(src_filename, context: str, start, end: int):
    """
    Replace lines start..end (1-based, inclusive) of a file with the given text, in place.

    The new content is first written to a temporary file, which is then copied over the
    source file. The temporary file is always removed, also when a step fails.

    Args:
    - src_filename: path of the file to modify
    - context: the replacement text; when empty, the lines are only removed
    - start: 1-based number of the first line to replace
    - end: 1-based number of the last line to replace
    """
    with open(src_filename, "r") as src:
        lines = src.readlines()

    # Clamp so that start=0 keeps nothing before the suggestion
    # (a bare "start - 1" would turn into the slice [:-1], i.e. nearly the whole file).
    head = lines[:max(start - 1, 0)]
    tail = lines[end:]

    tmp_name = None
    try:
        with NamedTemporaryFile("w", delete=False) as tmp:
            tmp_name = tmp.name
            tmp.writelines(head)
            if context:
                tmp.write(context)
            tmp.writelines(tail)
        # The temporary file is closed (and therefore flushed) before it is copied
        shutil.copy(tmp_name, src_filename)
    finally:
        if tmp_name is not None:
            os.remove(tmp_name)
def upload_patch(patch, path, timeout=30):
    """
    Upload a patch to the patch server configured in the gerrit settings.

    The endpoint and the bearer token are read from the 'gerrit.patch_server_endpoint'
    and 'gerrit.patch_server_token' settings.

    Args:
    - patch: the patch content
    - path: the path under which the patch is stored on the server
    - timeout: seconds to wait for the server before giving up
      (without a timeout, requests waits indefinitely)
    Returns:
    - str: the endpoint (without trailing slashes) joined with the path
    Raises:
    - requests.HTTPError: when the server answers with an error status
    - requests.Timeout: when the server does not answer within the timeout
    """
    settings = get_settings()
    patch_server_endpoint = settings.get('gerrit.patch_server_endpoint')
    patch_server_token = settings.get('gerrit.patch_server_token')

    response = requests.post(
        patch_server_endpoint,
        json={
            "content": patch,
            "path": path,
        },
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {patch_server_token}",
        },
        timeout=timeout,
    )
    response.raise_for_status()
    patch_server_endpoint = patch_server_endpoint.rstrip("/")
    return patch_server_endpoint + "/" + path
class GerritProvider(GitProvider):
    """
    GitProvider implementation for Gerrit changes.

    The change is cloned into a local temporary directory and inspected through GitPython;
    comments are sent to and read from Gerrit over SSH.
    """

    def __init__(self, key: str, incremental=False):
        """
        Args:
        - key: "<project>:<refspec>" identifying the change to work on
        - incremental: not used by this provider
        """
        project, refspec = key.split(':')
        self.project = project
        self.refspec = refspec
        assert self.project, "Project name is required"
        assert self.refspec, "Refspec is required"
        base_url = get_settings().get('gerrit.url')
        assert base_url, "Gerrit URL is required"
        user = get_settings().get('gerrit.user')
        assert user, "Gerrit user is required"

        # Rebuild the configured URL with the configured user as its auth part
        configured = urllib3.util.parse_url(base_url)
        url_with_user = f"{configured.scheme}://{user}@{configured.host}:{configured.port}"
        self.parsed_url = urllib3.util.parse_url(url_with_user)

        self.repo_path = prepare_repo(
            self.parsed_url, self.project, self.refspec
        )
        self.repo = Repo(self.repo_path)
        assert self.repo
        self.pr_url = base_url
        title = self.get_pr_title()
        self.pr = PullRequestMimic(title, self.get_diff_files())

    def get_pr_title(self):
        """
        Substitutes the branch-name as the PR-mimic title.
        """
        first_branch = self.repo.branches[0]
        return first_branch.name

    def get_issue_comments(self):
        """
        Return the comments of the change, in reversed order of the Gerrit query result,
        wrapped in an object whose `reversed` attribute holds the list.
        Each comment exposes the Gerrit message as `body`.
        """
        Comments = namedtuple('Comments', ['reversed'])
        Comment = namedtuple('Comment', ['body'])

        gerrit_comments = list_comments(self.parsed_url, self.refspec)
        wrapped = []
        for entry in reversed(gerrit_comments):
            wrapped.append(Comment(entry['message']))
        return Comments(wrapped)

    def get_pr_labels(self, update=False):
        raise NotImplementedError(
            'Getting labels is not implemented for the gerrit provider')

    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False):
        raise NotImplementedError(
            'Adding reactions is not implemented for the gerrit provider')

    def remove_reaction(self, issue_comment_id: int, reaction_id: int):
        raise NotImplementedError(
            'Removing reactions is not implemented for the gerrit provider')

    def get_commit_messages(self):
        head_commit = self.repo.head.commit
        return [head_commit.message]

    def get_repo_settings(self):
        """
        Return the raw bytes of the ".pr_agent.toml" file in the cloned repository,
        or empty bytes when the file cannot be read.
        """
        settings_path = self.repo_path / ".pr_agent.toml"
        try:
            with open(settings_path, 'rb') as settings_file:
                return settings_file.read()
        except OSError:
            return b""

    def get_diff_files(self) -> list[FilePatchInfo]:
        """
        Return the files changed by the head commit relative to its first parent,
        with their contents on both sides and the patch text.
        The result is also stored in self.diff_files.
        """

        def _blob_text(blob):
            # A missing blob is treated as an empty file
            if blob is None:
                return ""
            return blob.data_stream.read().decode('utf-8')

        head_commit = self.repo.head.commit
        changes = head_commit.diff(
            self.repo.head.commit.parents[0],  # previous commit
            create_patch=True,
            R=True
        )

        collected = []
        for change in changes:
            before_text = _blob_text(change.a_blob)
            after_text = _blob_text(change.b_blob)

            # Anything that is not flagged as added, deleted or renamed counts as modified
            if change.new_file:
                kind = EDIT_TYPE.ADDED
            elif change.deleted_file:
                kind = EDIT_TYPE.DELETED
            elif change.renamed_file:
                kind = EDIT_TYPE.RENAMED
            else:
                kind = EDIT_TYPE.MODIFIED

            # The old name is only reported when it differs from the new one
            previous_name = None if change.a_path == change.b_path else change.a_path
            patch_text = change.diff.decode('utf-8')

            collected.append(
                FilePatchInfo(
                    before_text,
                    after_text,
                    patch_text,
                    change.b_path,
                    edit_type=kind,
                    old_filename=previous_name,
                )
            )

        self.diff_files = collected
        return collected

    def get_files(self):
        """
        Return the "a" paths of the files changed by the head commit relative to its first parent.
        """
        head_commit = self.repo.head.commit
        changes = head_commit.diff(
            self.repo.head.commit.parents[0],  # previous commit
            R=True
        )
        # Get the list of changed files
        changed_paths = []
        for change in changes:
            changed_paths.append(change.a_path)
        return changed_paths

    def get_languages(self):
        """
        Calculate percentage of languages in repository. Used for hunk
        prioritisation.
        Languages are identified by the lower-cased file extension without its leading dot.
        """
        # Get all files in repository
        repo_files = []
        for entry in self.repo.tree().traverse():
            if entry.type == 'blob':
                repo_files.append(Path(entry.path))

        # Identify language by file extension and count
        counts = Counter(
            repo_file.suffix.lower().lstrip('.') for repo_file in repo_files
        )

        # Convert counts to percentages
        file_total = len(repo_files)
        return {
            extension: occurrences / file_total * 100
            for extension, occurrences in counts.items()
        }

    def get_pr_description_full(self):
        head_commit = self.repo.head.commit
        return head_commit.message

    def get_user_id(self):
        author = self.repo.head.commit.author
        return author.email

    def is_supported(self, capability: str) -> bool:
        """
        Returns:
        - bool: False for the capabilities this provider does not implement, True otherwise.
        """
        unsupported = (
            # 'get_issue_comments',
            'create_inline_comment',
            'publish_inline_comments',
            'get_labels',
            'gfm_markdown',
        )
        return capability not in unsupported

    def split_suggestion(self, msg) -> tuple[str, str]:
        """
        Split a suggestion message into its description and its suggested code.

        Lines between a "```suggestion" fence and the next "```" fence are the code;
        all other lines, with "*" characters removed, are the description.
        The fence lines themselves are dropped.

        Returns:
        - tuple[str, str]: the description, and the code followed by a newline
          (an empty string when there is no code)
        """
        description_lines = []
        code_lines = []
        inside_suggestion = False
        for line in msg.splitlines():
            if line.startswith('```'):
                # A "```suggestion" fence switches code mode on, any other fence switches it off
                inside_suggestion = line.startswith('```suggestion')
                continue
            if inside_suggestion:
                code_lines.append(line)
            else:
                description_lines.append(line.replace('*', ''))

        if code_lines:
            code_text = '\n'.join(code_lines) + '\n'
        else:
            code_text = ''
        return '\n'.join(description_lines), code_text

    def publish_code_suggestions(self, code_suggestions: list):
        """
        For every suggestion: apply it to the local checkout, upload the resulting diff
        to the patch server, and undo the local change. Finally post one Gerrit comment
        listing each description with the link to its patch.

        Returns:
        - bool: always True
        """
        entries = []
        for suggestion in code_suggestions:
            description, code = self.split_suggestion(suggestion['body'])
            target_file = pathlib.Path(self.repo_path) / suggestion["relevant_file"]
            add_suggestion(
                target_file,
                code,
                suggestion["relevant_lines_start"],
                suggestion["relevant_lines_end"],
            )
            patch = diff(cwd=self.repo_path)
            # Short random id that distinguishes the patches uploaded for this refspec
            patch_id = uuid.uuid4().hex[:4]
            upload_path = "/".join(["codium-ai", self.refspec, patch_id])
            patch_link = upload_patch(patch, upload_path)
            reset_local_changes(self.repo_path)
            entries.append(f'* {description}\n{patch_link}')

        if entries:
            add_comment(self.parsed_url, self.refspec, "\n".join(entries))
        return True

    def publish_comment(self, pr_comment: str, is_temporary: bool = False):
        """
        Post the comment on the Gerrit change. Temporary comments are ignored.
        """
        if is_temporary:
            return
        gerrit_text = adopt_to_gerrit_message(pr_comment)
        add_comment(self.parsed_url, self.refspec, gerrit_text)

    def publish_description(self, pr_title: str, pr_body: str):
        """
        Post the title and the converted body as one comment on the Gerrit change.
        """
        gerrit_text = adopt_to_gerrit_message(pr_body)
        add_comment(self.parsed_url, self.refspec, pr_title + '\n' + gerrit_text)

    def publish_inline_comments(self, comments: list[dict]):
        raise NotImplementedError(
            'Publishing inline comments is not implemented for the gerrit '
            'provider')

    def publish_inline_comment(self, body: str, relevant_file: str,
                               relevant_line_in_file: str, original_suggestion=None):
        raise NotImplementedError(
            'Publishing inline comments is not implemented for the gerrit '
            'provider')

    def publish_labels(self, labels):
        # Not applicable to the gerrit provider,
        # but required by the interface
        return None

    def remove_initial_comment(self):
        # The repository cloned in the previous steps is left in place
        # (its removal is currently disabled)
        return None

    def remove_comment(self, comment):
        # Nothing to do for this provider
        return None

    def get_pr_branch(self):
        # Note: this is the repository's head reference object, not a branch name
        return self.repo.head

View File

@ -0,0 +1,350 @@
from abc import ABC, abstractmethod
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
from typing import Optional
from utils.pr_agent.algo.types import FilePatchInfo
from utils.pr_agent.algo.utils import Range, process_description
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.log import get_logger
# NOTE(review): not referenced in the visible part of this module; presumably an upper bound
# on the number of PR files that get full processing - confirm against the callers.
MAX_FILES_ALLOWED_FULL = 50
class GitProvider(ABC):
@abstractmethod
def is_supported(self, capability: str) -> bool:
pass
@abstractmethod
def get_files(self) -> list:
pass
@abstractmethod
def get_diff_files(self) -> list[FilePatchInfo]:
pass
def get_incremental_commits(self, is_incremental):
pass
@abstractmethod
def publish_description(self, pr_title: str, pr_body: str):
pass
@abstractmethod
def publish_code_suggestions(self, code_suggestions: list) -> bool:
pass
@abstractmethod
def get_languages(self):
pass
@abstractmethod
def get_pr_branch(self):
pass
@abstractmethod
def get_user_id(self):
pass
@abstractmethod
def get_pr_description_full(self) -> str:
pass
def edit_comment(self, comment, body: str):
pass
def edit_comment_from_comment_id(self, comment_id: int, body: str):
pass
def get_comment_body_from_comment_id(self, comment_id: int) -> str:
pass
def reply_to_comment_from_comment_id(self, comment_id: int, body: str):
pass
def get_pr_description(self, full: bool = True, split_changes_walkthrough=False) -> str or tuple:
from utils.pr_agent.algo.utils import clip_tokens
from utils.pr_agent.config_loader import get_settings
max_tokens_description = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
description = self.get_pr_description_full() if full else self.get_user_description()
if split_changes_walkthrough:
description, files = process_description(description)
if max_tokens_description:
description = clip_tokens(description, max_tokens_description)
return description, files
else:
if max_tokens_description:
description = clip_tokens(description, max_tokens_description)
return description
def get_user_description(self) -> str:
if hasattr(self, 'user_description') and not (self.user_description is None):
return self.user_description
description = (self.get_pr_description_full() or "").strip()
description_lowercase = description.lower()
get_logger().debug(f"Existing description", description=description_lowercase)
# if the existing description wasn't generated by the pr-agent, just return it as-is
if not self._is_generated_by_pr_agent(description_lowercase):
get_logger().info(f"Existing description was not generated by the pr-agent")
self.user_description = description
return description
# if the existing description was generated by the pr-agent, but it doesn't contain a user description,
# return nothing (empty string) because it means there is no user description
user_description_header = "### **user description**"
if user_description_header not in description_lowercase:
get_logger().info(f"Existing description was generated by the pr-agent, but it doesn't contain a user description")
return ""
# otherwise, extract the original user description from the existing pr-agent description and return it
# user_description_start_position = description_lowercase.find(user_description_header) + len(user_description_header)
# return description[user_description_start_position:].split("\n", 1)[-1].strip()
# the 'user description' is in the beginning. extract and return it
possible_headers = self._possible_headers()
start_position = description_lowercase.find(user_description_header) + len(user_description_header)
end_position = len(description)
for header in possible_headers: # try to clip at the next header
if header != user_description_header and header in description_lowercase:
end_position = min(end_position, description_lowercase.find(header))
if end_position != len(description) and end_position > start_position:
original_user_description = description[start_position:end_position].strip()
if original_user_description.endswith("___"):
original_user_description = original_user_description[:-3].strip()
else:
original_user_description = description.split("___")[0].strip()
if original_user_description.lower().startswith(user_description_header):
original_user_description = original_user_description[len(user_description_header):].strip()
get_logger().info(f"Extracted user description from existing description",
description=original_user_description)
self.user_description = original_user_description
return original_user_description
def _possible_headers(self):
return ("### **user description**", "### **pr type**", "### **pr description**", "### **pr labels**", "### **type**", "### **description**",
"### **labels**", "### 🤖 generated by pr agent")
def _is_generated_by_pr_agent(self, description_lowercase: str) -> bool:
possible_headers = self._possible_headers()
return any(description_lowercase.startswith(header) for header in possible_headers)
@abstractmethod
def get_repo_settings(self):
pass
def get_workspace_name(self):
return ""
def get_pr_id(self):
return ""
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
return ""
def get_lines_link_original_file(self, filepath:str, component_range: Range) -> str:
return ""
#### comments operations ####
@abstractmethod
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
pass
def publish_persistent_comment(self, pr_comment: str,
initial_header: str,
update_header: bool = True,
name='review',
final_update_message=True):
self.publish_comment(pr_comment)
def publish_persistent_comment_full(self, pr_comment: str,
initial_header: str,
update_header: bool = True,
name='review',
final_update_message=True):
try:
prev_comments = list(self.get_issue_comments())
for comment in prev_comments:
if comment.body.startswith(initial_header):
latest_commit_url = self.get_latest_commit_url()
comment_url = self.get_comment_url(comment)
if update_header:
updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n"
pr_comment_updated = pr_comment.replace(initial_header, updated_header)
else:
pr_comment_updated = pr_comment
get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message")
# response = self.mr.notes.update(comment.id, {'body': pr_comment_updated})
self.edit_comment(comment, pr_comment_updated)
if final_update_message:
self.publish_comment(
f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}")
return
except Exception as e:
get_logger().exception(f"Failed to update persistent review, error: {e}")
pass
self.publish_comment(pr_comment)
@abstractmethod
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
pass
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
absolute_position: int = None):
raise NotImplementedError("This git provider does not support creating inline comments yet")
@abstractmethod
def publish_inline_comments(self, comments: list[dict]):
pass
@abstractmethod
def remove_initial_comment(self):
pass
@abstractmethod
def remove_comment(self, comment):
pass
@abstractmethod
def get_issue_comments(self):
pass
def get_comment_url(self, comment) -> str:
return ""
#### labels operations ####
@abstractmethod
def publish_labels(self, labels):
pass
@abstractmethod
def get_pr_labels(self, update=False):
pass
def get_repo_labels(self):
    """Optional hook for listing repository labels; the base implementation returns None."""
    return None
@abstractmethod
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
    """Add an "eyes" reaction to the comment with the given id (abstract; no base behaviour)."""
    # NOTE(review): the Optional[int] result is presumably the id of the created reaction - confirm.
@abstractmethod
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
    """Remove reaction `reaction_id` from comment `issue_comment_id` (abstract; no base behaviour)."""
#### commits operations ####
@abstractmethod
def get_commit_messages(self):
    """Return the commit messages of the PR (abstract; no base behaviour)."""
def get_pr_url(self) -> str:
    """Return the instance's `pr_url` attribute, or an empty string when it is not set."""
    return getattr(self, 'pr_url', "")
def get_latest_commit_url(self) -> str:
    """Return a link to the latest commit; the base implementation returns an empty string."""
    no_url = ""
    return no_url
def auto_approve(self) -> bool:
    """Approve the PR automatically; the base implementation does nothing and reports False."""
    approved = False
    return approved
def calc_pr_statistics(self, pull_request_data: dict):
    """Compute statistics from `pull_request_data`; the base implementation returns an empty dict."""
    return dict()
def get_num_of_files(self):
    """Return how many files get_diff_files() reports, or -1 if they cannot be retrieved."""
    try:
        diff_files = self.get_diff_files()
        return len(diff_files)
    except Exception:
        # any failure while loading or counting the files is reported as -1
        return -1
def limit_output_characters(self, output: str, max_chars: int):
    """Return `output` unchanged if it fits in `max_chars`, else its first `max_chars` chars plus '...'."""
    if len(output) <= max_chars:
        return output
    truncated = output[:max_chars]
    return truncated + '...'
def get_main_pr_language(languages, files) -> str:
    """
    Get the main language of the commit. Return an empty string if cannot determine.

    Args:
        languages: Mapping of language name to a weight; the key with the largest value is
            taken as the repository's top language.
        files: Changed files, each either a filename string or an object with a `filename`
            attribute. Falsy entries are ignored.

    Returns:
        The lower-cased language name whose extension list (from the
        `language_extension_map_org` setting) contains the most common file extension of
        `files`; the repository's top language is preferred when it matches. Empty string
        when nothing matches or an error occurs.
    """
    from collections import Counter  # local import: only this function needs it

    main_language_str = ""
    if not languages:
        get_logger().info("No languages detected")
        return main_language_str
    if not files:
        get_logger().info("No files in diff")
        return main_language_str

    try:
        top_language = max(languages, key=languages.get).lower()

        # validate that the specific commit uses the main language
        extension_counts = Counter()
        for file in files:
            if not file:
                continue
            filename = file if isinstance(file, str) else file.filename
            extension_counts[filename.rsplit('.', 1)[-1]] += 1
        if not extension_counts:
            get_logger().info("No valid files in diff")
            return main_language_str

        # most_common() breaks ties by first occurrence, so the result is deterministic
        most_common_extension = '.' + extension_counts.most_common(1)[0][0]
        try:
            language_extension_map_org = get_settings().language_extension_map_org
            language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}

            if top_language in language_extension_map and most_common_extension in language_extension_map[top_language]:
                main_language_str = top_language
            else:
                # fall back to the first language that declares this extension
                for language, extensions in language_extension_map.items():
                    if most_common_extension in extensions:
                        main_language_str = language
                        break
        except Exception as e:
            get_logger().exception(f"Failed to get main language: {e}")
    except Exception as e:
        get_logger().exception(e)

    return main_language_str
class IncrementalPR:
    """Plain state holder for incremental handling of a PR (flag, commit range, boundary commits)."""

    def __init__(self, is_incremental: bool = False):
        self.is_incremental = is_incremental
        # the following are filled in later by whoever uses this object
        self.commits_range = None
        self.first_new_commit = None
        self.last_seen_commit = None

    @property
    def first_new_commit_sha(self):
        """`sha` of `first_new_commit`, or None while that commit is unset."""
        if self.first_new_commit is None:
            return None
        return self.first_new_commit.sha

    @property
    def last_seen_commit_sha(self):
        """`sha` of `last_seen_commit`, or None while that commit is unset."""
        if self.last_seen_commit is None:
            return None
        return self.last_seen_commit.sha

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,591 @@
import difflib
import re
from typing import Optional, Tuple
from urllib.parse import urlparse
import gitlab
from gitlab import GitlabGetError
from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from ..algo.file_filter import filter_ignored
from ..algo.language_handler import is_valid_file
from ..algo.utils import (clip_tokens,
find_line_number_of_relevant_line_in_file,
load_large_diff)
from ..config_loader import get_settings
from ..log import get_logger
from .git_provider import MAX_FILES_ALLOWED_FULL, GitProvider
class DiffNotFoundError(Exception):
    """Signals that no diff could be obtained for a merge request."""
class GitLabProvider(GitProvider):
    """GitProvider implementation that works on a single GitLab merge request via python-gitlab."""

    def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
        """
        Create the GitLab client from settings and load the merge request.

        Raises:
            ValueError: If `GITLAB.URL` or `GITLAB.PERSONAL_ACCESS_TOKEN` is not configured,
                or the URL is not a merge request URL.
            DiffNotFoundError: If the merge request has no diff versions.
        """
        gitlab_url = get_settings().get("GITLAB.URL", None)
        if not gitlab_url:
            raise ValueError("GitLab URL is not set in the config file")
        self.gitlab_url = gitlab_url
        gitlab_access_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
        if not gitlab_access_token:
            raise ValueError("GitLab personal access token is not set in the config file")
        self.gl = gitlab.Gitlab(
            url=gitlab_url,
            oauth_token=gitlab_access_token
        )
        self.max_comment_chars = 65000  # every published/edited body is clipped to this length
        self.id_project = None
        self.id_mr = None
        self.mr = None
        self.diff_files = None
        self.git_files = None
        self.temp_comments = []
        self.pr_url = merge_request_url
        self._set_merge_request(merge_request_url)
        self.RE_HUNK_HEADER = re.compile(
            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
        self.incremental = incremental

    def is_supported(self, capability: str) -> bool:
        """Return False for the capabilities this provider declares unsupported, True otherwise."""
        if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments',
                          'publish_file_comments']:  # gfm_markdown is supported in gitlab !
            return False
        return True

    @property
    def pr(self):
        '''The GitLab terminology is merge request (MR) instead of pull request (PR)'''
        return self.mr

    def _set_merge_request(self, merge_request_url: str):
        """Parse the URL, fetch the merge request and remember the last entry of its diff list."""
        self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url)
        self.mr = self._get_merge_request()
        try:
            self.last_diff = self.mr.diffs.list(get_all=True)[-1]
        except IndexError as e:
            get_logger().error(f"Could not get diff for merge request {self.id_mr}")
            raise DiffNotFoundError(f"Could not get diff for merge request {self.id_mr}") from e

    def get_pr_file_content(self, file_path: str, branch: str) -> str:
        """Return the decoded content of `file_path` at ref `branch`, or '' if the file does not exist."""
        try:
            return self.gl.projects.get(self.id_project).files.get(file_path, branch).decode()
        except GitlabGetError:
            # In case of file creation the method returns GitlabGetError (404 file not found).
            # In this case we return an empty string for the diff.
            return ''

    def get_diff_files(self) -> list[FilePatchInfo]:
        """
        Retrieves the list of files that have been modified, added, deleted, or renamed in a pull request in GitLab,
        along with their content and patch information.

        The result is cached on the instance after the first successful (non-empty) load.

        Returns:
            diff_files (List[FilePatchInfo]): List of FilePatchInfo objects representing the modified, added, deleted,
            or renamed files in the merge request.
        """
        if self.diff_files:
            return self.diff_files

        # filter files using [ignore] patterns
        diffs_original = self.mr.changes()['changes']
        diffs = filter_ignored(diffs_original, 'gitlab')
        if diffs != diffs_original:
            try:
                names_original = [diff['new_path'] for diff in diffs_original]
                names_filtered = [diff['new_path'] for diff in diffs]
                get_logger().info(f"Filtered out [ignore] files for merge request {self.id_mr}", extra={
                    'original_files': names_original,
                    'filtered_files': names_filtered
                })
            except Exception:
                pass  # reporting what was filtered is best-effort only

        diff_files = []
        invalid_files_names = []
        counter_valid = 0
        for diff in diffs:
            if not is_valid_file(diff['new_path']):
                invalid_files_names.append(diff['new_path'])
                continue

            # allow only a limited number of files to be fully loaded. We can manage the rest with diffs only
            counter_valid += 1
            if counter_valid < MAX_FILES_ALLOWED_FULL or not diff['diff']:
                original_file_content_str = self.get_pr_file_content(diff['old_path'], self.mr.diff_refs['base_sha'])
                new_file_content_str = self.get_pr_file_content(diff['new_path'], self.mr.diff_refs['head_sha'])
            else:
                if counter_valid == MAX_FILES_ALLOWED_FULL:
                    get_logger().info(f"Too many files in PR, will avoid loading full content for rest of files")
                original_file_content_str = ''
                new_file_content_str = ''

            try:
                if isinstance(original_file_content_str, bytes):
                    original_file_content_str = bytes.decode(original_file_content_str, 'utf-8')
                if isinstance(new_file_content_str, bytes):
                    new_file_content_str = bytes.decode(new_file_content_str, 'utf-8')
            except UnicodeDecodeError:
                get_logger().warning(
                    f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}")

            edit_type = EDIT_TYPE.MODIFIED
            if diff['new_file']:
                edit_type = EDIT_TYPE.ADDED
            elif diff['deleted_file']:
                edit_type = EDIT_TYPE.DELETED
            elif diff['renamed_file']:
                edit_type = EDIT_TYPE.RENAMED

            filename = diff['new_path']
            patch = diff['diff']
            if not patch:
                # GitLab returned no patch text for this file; build one from the two contents
                patch = load_large_diff(filename, new_file_content_str, original_file_content_str)

            # count number of lines added and removed
            patch_lines = patch.splitlines(keepends=True)
            num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
            num_minus_lines = len([line for line in patch_lines if line.startswith('-')])
            diff_files.append(
                FilePatchInfo(original_file_content_str, new_file_content_str,
                              patch=patch,
                              filename=filename,
                              edit_type=edit_type,
                              old_filename=None if diff['old_path'] == diff['new_path'] else diff['old_path'],
                              num_plus_lines=num_plus_lines,
                              num_minus_lines=num_minus_lines, ))
        if invalid_files_names:
            get_logger().info(f"Filtered out files with invalid extensions: {invalid_files_names}")

        self.diff_files = diff_files
        return diff_files

    def get_files(self) -> list:
        """Return (and cache) the `new_path` of every change in the merge request."""
        if not self.git_files:
            self.git_files = [change['new_path'] for change in self.mr.changes()['changes']]
        return self.git_files

    def publish_description(self, pr_title: str, pr_body: str):
        """Set the merge request title and description; failures are logged, not raised."""
        try:
            self.mr.title = pr_title
            self.mr.description = pr_body
            self.mr.save()
        except Exception as e:
            get_logger().exception(f"Could not update merge request {self.id_mr} description: {e}")

    def get_latest_commit_url(self):
        """Return the web URL of the first commit yielded by the merge request's commit list."""
        return self.mr.commits().next().web_url

    def get_comment_url(self, comment):
        """Return the merge request URL with a `#note_<id>` anchor for `comment`."""
        return f"{self.mr.web_url}#note_{comment.id}"

    def publish_persistent_comment(self, pr_comment: str,
                                   initial_header: str,
                                   update_header: bool = True,
                                   name='review',
                                   final_update_message=True):
        """Delegate to publish_persistent_comment_full with the same arguments."""
        self.publish_persistent_comment_full(pr_comment, initial_header, update_header, name, final_update_message)

    def publish_comment(self, mr_comment: str, is_temporary: bool = False):
        """
        Create a note on the merge request and return it.

        Temporary comments are skipped (returning None) unless `config.publish_output_progress`
        is enabled; published temporary comments are remembered for remove_initial_comment().
        """
        if is_temporary and not get_settings().config.publish_output_progress:
            get_logger().debug(f"Skipping publish_comment for temporary comment: {mr_comment}")
            return None
        mr_comment = self.limit_output_characters(mr_comment, self.max_comment_chars)
        comment = self.mr.notes.create({'body': mr_comment})
        if is_temporary:
            self.temp_comments.append(comment)
        return comment

    def edit_comment(self, comment, body: str):
        """Replace the body of the given note with `body` (clipped to the maximum length)."""
        body = self.limit_output_characters(body, self.max_comment_chars)
        self.mr.notes.update(comment.id, {'body': body})

    def edit_comment_from_comment_id(self, comment_id: int, body: str):
        """Fetch the note by id and save it with the new (clipped) body."""
        body = self.limit_output_characters(body, self.max_comment_chars)
        comment = self.mr.notes.get(comment_id)
        comment.body = body
        comment.save()

    def reply_to_comment_from_comment_id(self, comment_id: int, body: str):
        """Add a note with `body` to the discussion looked up by `comment_id`."""
        body = self.limit_output_characters(body, self.max_comment_chars)
        discussion = self.mr.discussions.get(comment_id)
        discussion.notes.create({'body': body})

    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
        """Locate `relevant_line_in_file` in the diff of `relevant_file` and comment on it."""
        body = self.limit_output_characters(body, self.max_comment_chars)
        edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file,
                                                                                         relevant_line_in_file)
        self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
                                 target_file, target_line_no, original_suggestion)

    def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, absolute_position: int = None):
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Gitlab provider does not support creating inline comments yet")

    def create_inline_comments(self, comments: list[dict]):
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Gitlab provider does not support publishing inline comments yet")

    def get_comment_body_from_comment_id(self, comment_id: int):
        """Return the body of the note with the given id."""
        comment = self.mr.notes.get(comment_id).body
        return comment

    def send_inline_comment(self, body: str, edit_type: str, found: bool, relevant_file: str,
                            relevant_line_in_file: str,
                            source_line_no: int, target_file: str, target_line_no: int,
                            original_suggestion=None) -> None:
        """
        Create a positioned discussion on the merge request diff.

        If GitLab rejects the position and `original_suggestion` is available, a file-level
        note describing the suggestion (with a rendered diff) is created instead. Failures of
        the API calls are logged, not raised.

        Args:
            body: Comment text.
            edit_type: 'addition', 'deletion' or anything else for a context line.
            found: Whether the line was located; when False nothing is published.
            target_file: Despite the annotation, used as an object exposing `filename`
                and `old_filename`.
            source_line_no, target_line_no: One-past line counters; 1 is subtracted before use.
            original_suggestion: Suggestion dict used only to build the fallback note.

        Raises:
            DiffNotFoundError: If no diff version can be determined for the comment.
        """
        if not found:
            get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
            return

        # in order to have exact sha's we have to find correct diff for this change
        diff = self.get_relevant_diff(relevant_file, relevant_line_in_file)
        if diff is None:
            get_logger().error(f"Could not get diff for merge request {self.id_mr}")
            raise DiffNotFoundError(f"Could not get diff for merge request {self.id_mr}")
        pos_obj = {'position_type': 'text',
                   'new_path': target_file.filename,
                   'old_path': target_file.old_filename if target_file.old_filename else target_file.filename,
                   'base_sha': diff.base_commit_sha, 'start_sha': diff.start_commit_sha, 'head_sha': diff.head_commit_sha}
        if edit_type == 'deletion':
            pos_obj['old_line'] = source_line_no - 1
        elif edit_type == 'addition':
            pos_obj['new_line'] = target_line_no - 1
        else:
            pos_obj['new_line'] = target_line_no - 1
            pos_obj['old_line'] = source_line_no - 1
        get_logger().debug(f"Creating comment in MR {self.id_mr} with body {body} and position {pos_obj}")
        try:
            self.mr.discussions.create({'body': body, 'position': pos_obj})
        except Exception:
            if not original_suggestion:
                # Nothing to build a fallback note from; report the real API failure.
                get_logger().exception(f"Failed to create comment in MR {self.id_mr}")
                return
            try:
                # fallback - create a general note on the file in the MR
                if 'suggestion_orig_location' in original_suggestion:
                    line_start = original_suggestion['suggestion_orig_location']['start_line']
                    line_end = original_suggestion['suggestion_orig_location']['end_line']
                    old_code_snippet = original_suggestion['prev_code_snippet']
                    new_code_snippet = original_suggestion['new_code_snippet']
                    content = original_suggestion['suggestion_summary']
                    label = original_suggestion['category']
                    if 'score' in original_suggestion:
                        score = original_suggestion['score']
                    else:
                        score = 7
                else:
                    line_start = original_suggestion['relevant_lines_start']
                    line_end = original_suggestion['relevant_lines_end']
                    old_code_snippet = original_suggestion['existing_code']
                    new_code_snippet = original_suggestion['improved_code']
                    content = original_suggestion['suggestion_content']
                    label = original_suggestion['label']
                    score = original_suggestion.get('score', 7)

                link = self.get_line_link(relevant_file, line_start, line_end)
                body_fallback = f"**Suggestion:** {content} [{label}, importance: {score}]\n\n"
                body_fallback += f"\n\n<details><summary>[{target_file.filename} [{line_start}-{line_end}]]({link}):</summary>\n\n"
                body_fallback += f"\n\n___\n\n`(Cannot implement directly - GitLab API allows committable suggestions strictly on MR diff lines)`"
                body_fallback += "</details>\n\n"
                diff_patch = difflib.unified_diff(old_code_snippet.split('\n'),
                                                  new_code_snippet.split('\n'), n=999)
                patch_orig = "\n".join(diff_patch)
                # drop the leading header lines that unified_diff emits before the content
                patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
                diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
                body_fallback += diff_code

                # Create a general note on the file in the MR
                self.mr.notes.create({
                    'body': body_fallback,
                    'position': {
                        'base_sha': diff.base_commit_sha,
                        'start_sha': diff.start_commit_sha,
                        'head_sha': diff.head_commit_sha,
                        'position_type': 'text',
                        'file_path': f'{target_file.filename}',
                    }
                })
                get_logger().debug(f"Created fallback comment in MR {self.id_mr} with position {pos_obj}")
            except Exception:
                get_logger().exception(f"Failed to create comment in MR {self.id_mr}")

    def get_relevant_diff(self, relevant_file: str, relevant_line_in_file: str) -> Optional[dict]:
        """
        Return the diff version object whose SHAs should be used for a comment position.

        Returns the first listed diff version when some change of `relevant_file` contains
        `relevant_line_in_file`, otherwise the diff remembered in `self.last_diff`; None when
        the merge request has no changes or no diff versions.
        """
        changes = self.mr.changes()  # Retrieve the changes for the merge request once
        if not changes:
            get_logger().error('No changes found for the merge request.')
            return None
        all_diffs = self.mr.diffs.list(get_all=True)
        if not all_diffs:
            get_logger().error('No diffs found for the merge request.')
            return None
        # NOTE(review): the match test depends only on the MR changes, never on the individual
        # diff version, so a match always selects the first listed version - confirm intent.
        for change in changes['changes']:
            if change['new_path'] == relevant_file and relevant_line_in_file in change['diff']:
                return all_diffs[0]
        get_logger().debug(
            f'No relevant diff found for {relevant_file} {relevant_line_in_file}. Falling back to last diff.')
        return self.last_diff  # fallback to last_diff if no relevant diff is found

    def publish_code_suggestions(self, code_suggestions: list) -> bool:
        """
        Publish each suggestion as an inline comment carrying a GitLab suggestion block.

        Suggestions are published one by one; a failure is logged and the remaining
        suggestions are still attempted. Always returns True.
        """
        for suggestion in code_suggestions:
            try:
                if suggestion and 'original_suggestion' in suggestion:
                    original_suggestion = suggestion['original_suggestion']
                else:
                    original_suggestion = suggestion
                body = suggestion['body']
                relevant_file = suggestion['relevant_file']
                relevant_lines_start = suggestion['relevant_lines_start']
                relevant_lines_end = suggestion['relevant_lines_end']

                target_file = None
                for file in self.get_diff_files():
                    if file.filename == relevant_file:
                        target_file = file
                        break
                if target_file is None:
                    get_logger().error(
                        f"Could not publish code suggestion:\nsuggestion: {suggestion}\n"
                        f"error: file {relevant_file} is not part of the merge request diff")
                    continue

                lines_delta = relevant_lines_end - relevant_lines_start  # no need to add 1
                body = body.replace('```suggestion', f'```suggestion:-0+{lines_delta}')
                lines = target_file.head_file.splitlines()
                relevant_line_in_file = lines[relevant_lines_start - 1]

                # for code suggestions, we want to edit the new code
                source_line_no = -1
                target_line_no = relevant_lines_start + 1  # send_inline_comment subtracts 1 again
                found = True
                edit_type = 'addition'

                self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
                                         target_file, target_line_no, original_suggestion)
            except Exception as e:
                get_logger().exception(f"Could not publish code suggestion:\nsuggestion: {suggestion}\nerror: {e}")

        # note that we publish suggestions one-by-one. so, if one fails, the rest will still be published
        return True

    def publish_file_comments(self, file_comments: list) -> bool:
        """No-op: file-level comments are declared unsupported in is_supported(); returns None."""
        pass

    def search_line(self, relevant_file, relevant_line_in_file):
        """
        Locate `relevant_line_in_file` inside the patch of `relevant_file`.

        Returns:
            Tuple (edit_type, found, source_line_no, target_file, target_line_no). When the
            file is not part of the diff, `found` is False, `target_file` is None and both
            line numbers are 0.
        """
        target_file = None
        found = False
        source_line_no = 0
        target_line_no = 0

        edit_type = self.get_edit_type(relevant_line_in_file)
        for file in self.get_diff_files():
            if file.filename == relevant_file:
                edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(file,
                                                                                                  relevant_line_in_file)
        return edit_type, found, source_line_no, target_file, target_line_no

    def find_in_file(self, file, relevant_line_in_file):
        """
        Scan `file.patch` for the first line containing `relevant_line_in_file`.

        The returned line counters have already been advanced past the matching line, i.e.
        they are one greater than the line's number in the old/new file.

        Returns:
            Tuple (edit_type, found, source_line_no, target_file, target_line_no).
        """
        edit_type = 'context'
        source_line_no = 0
        target_line_no = 0
        found = False
        target_file = file
        if not relevant_line_in_file:
            # an empty needle is contained in every line and would match arbitrarily
            return edit_type, found, source_line_no, target_file, target_line_no

        patch = file.patch
        patch_lines = patch.splitlines()
        for line in patch_lines:
            if line.startswith('@@'):
                match = self.RE_HUNK_HEADER.match(line)
                if not match:
                    continue
                start_old, size_old, start_new, size_new, _ = match.groups()
                source_line_no = int(start_old)
                target_line_no = int(start_new)
                continue
            if line.startswith('-'):
                source_line_no += 1
            elif line.startswith('+'):
                target_line_no += 1
            elif line.startswith(' '):
                source_line_no += 1
                target_line_no += 1
            if relevant_line_in_file in line:
                found = True
                edit_type = self.get_edit_type(line)
                break
            elif relevant_line_in_file.startswith('+') and relevant_line_in_file[1:].lstrip() in line:
                # The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
                # it's a context line
                found = True
                edit_type = self.get_edit_type(line)
                break
        return edit_type, found, source_line_no, target_file, target_line_no

    def get_edit_type(self, relevant_line_in_file):
        """Classify a patch line by its first character: 'deletion', 'addition' or 'context'."""
        # startswith() also copes with an empty line, which indexing [0] would not
        if relevant_line_in_file.startswith('-'):
            return 'deletion'
        if relevant_line_in_file.startswith('+'):
            return 'addition'
        return 'context'

    def remove_initial_comment(self):
        """Delete every temporary comment published so far; failures are logged."""
        try:
            for comment in self.temp_comments:
                self.remove_comment(comment)
        except Exception as e:
            get_logger().exception(f"Failed to remove temp comments, error: {e}")

    def remove_comment(self, comment):
        """Delete the given note; failures are logged, not raised."""
        try:
            comment.delete()
        except Exception as e:
            get_logger().exception(f"Failed to remove comment, error: {e}")

    def get_title(self):
        """Return the merge request title."""
        return self.mr.title

    def get_languages(self):
        """Return the project's language statistics as reported by GitLab."""
        languages = self.gl.projects.get(self.id_project).languages()
        return languages

    def get_pr_branch(self):
        """Return the merge request's source branch name."""
        return self.mr.source_branch

    def get_pr_owner_id(self) -> str | None:
        """
        Return an owner identifier for the merge request.

        For gitlab.com (or an unset URL) this is the first segment of the project path, or
        None when the project is unknown; for other instances it is the GitLab host name.
        """
        if not self.gitlab_url or 'gitlab.com' in self.gitlab_url:
            if not self.id_project:
                return None
            return self.id_project.split('/')[0]
        # extract host name
        host = urlparse(self.gitlab_url).hostname
        return host

    def get_pr_description_full(self):
        """Return the merge request description."""
        return self.mr.description

    def get_issue_comments(self):
        """Return all notes of the merge request, in reverse of the order returned by the API."""
        return self.mr.notes.list(get_all=True)[::-1]

    def get_repo_settings(self):
        """Return the decoded `.pr_agent.toml` from the target branch, or "" if it cannot be read."""
        try:
            contents = self.gl.projects.get(self.id_project).files.get(file_path='.pr_agent.toml', ref=self.mr.target_branch).decode()
            return contents
        except Exception:
            return ""

    def get_workspace_name(self):
        """Return the first segment of the project path."""
        return self.id_project.split('/')[0]

    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
        """No reaction is created by this provider; always returns True."""
        return True

    def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
        """No reaction is removed by this provider; always returns True."""
        return True

    def _parse_merge_request_url(self, merge_request_url: str) -> Tuple[str, int]:
        """
        Split a merge request URL into (project path, merge request id).

        Raises:
            ValueError: If the path has no 'merge_requests' segment, no id follows it, or the
                id is not an integer.
        """
        parsed_url = urlparse(merge_request_url)

        path_parts = parsed_url.path.strip('/').split('/')
        if 'merge_requests' not in path_parts:
            raise ValueError("The provided URL does not appear to be a GitLab merge request URL")

        mr_index = path_parts.index('merge_requests')
        # Ensure there is an ID after 'merge_requests'
        if len(path_parts) <= mr_index + 1:
            raise ValueError("The provided URL does not contain a merge request ID")

        try:
            mr_id = int(path_parts[mr_index + 1])
        except ValueError as e:
            raise ValueError("Unable to convert merge request ID to integer") from e

        # Handle special delimiter (-)
        project_path = "/".join(path_parts[:mr_index])
        if project_path.endswith('/-'):
            project_path = project_path[:-2]

        # Return the path before 'merge_requests' and the ID
        return project_path, mr_id

    def _get_merge_request(self):
        """Fetch the merge request object for `id_project` / `id_mr`."""
        mr = self.gl.projects.get(self.id_project).mergerequests.get(self.id_mr)
        return mr

    def get_user_id(self):
        """User ids are not provided by this provider; always returns None."""
        return None

    def publish_labels(self, pr_types):
        """Replace the merge request labels with the de-duplicated `pr_types`; failures are logged."""
        try:
            self.mr.labels = list(set(pr_types))
            self.mr.save()
        except Exception as e:
            get_logger().warning(f"Failed to publish labels, error: {e}")

    def publish_inline_comments(self, comments: list[dict]):
        """No-op: batch inline comments are declared unsupported in is_supported()."""
        pass

    def get_pr_labels(self, update=False):
        """Return the labels of the loaded merge request object (`update` is ignored)."""
        return self.mr.labels

    def get_repo_labels(self):
        """Return the project's labels as listed by the API."""
        return self.gl.projects.get(self.id_project).labels.list()

    def get_commit_messages(self):
        """
        Retrieves the commit messages of a pull request.

        Returns:
            str: A string containing the commit messages of the pull request, numbered one per
            line and clipped to `CONFIG.MAX_COMMITS_TOKENS` when that setting is present.
            Empty string if the commits cannot be read.
        """
        max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None)
        try:
            # NOTE(review): `_list` is a private attribute of the python-gitlab result object.
            commit_messages_list = [commit['message'] for commit in self.mr.commits()._list]
            commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages_list)])
        except Exception:
            commit_messages_str = ""
        if max_tokens:
            commit_messages_str = clip_tokens(commit_messages_str, max_tokens)
        return commit_messages_str

    def get_pr_id(self):
        """Return the merge request web URL as its identifier, or "" if it is unavailable."""
        try:
            pr_id = self.mr.web_url
            return pr_id
        except Exception:
            return ""

    def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
        """
        Build a blob link into the source branch for `relevant_file`.

        `relevant_line_start == -1` links to the whole file; otherwise the link carries an
        `#L<start>` or `#L<start>-<end>` anchor.
        """
        if relevant_line_start == -1:
            link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads"
        elif relevant_line_end:
            link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads#L{relevant_line_start}-{relevant_line_end}"
        else:
            link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads#L{relevant_line_start}"
        return link

    def generate_link_to_relevant_line_number(self, suggestion) -> str:
        """
        Return a blob link to the line named by `suggestion['relevant_line']`.

        Returns "" when the line is empty, cannot be located in the loaded diff files, or
        any error occurs.
        """
        try:
            relevant_file = suggestion['relevant_file'].strip('`').strip("'").rstrip()
            relevant_line_str = suggestion['relevant_line'].rstrip()
            if not relevant_line_str:
                return ""

            position, absolute_position = find_line_number_of_relevant_line_in_file \
                (self.diff_files, relevant_file, relevant_line_str)

            if absolute_position != -1:
                # link to right file only
                link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads#L{absolute_position}"
                return link
        except Exception as e:
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"Failed adding line link, error: {e}")

        return ""

View File

@ -0,0 +1,192 @@
from collections import Counter
from pathlib import Path
from typing import List
from git import Repo
from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from utils.pr_agent.config_loader import _find_repository_root, get_settings
from utils.pr_agent.git_providers.git_provider import GitProvider
from utils.pr_agent.log import get_logger
class PullRequestMimic:
    """
    Minimal stand-in for the PyGithub PullRequest object, used by the LocalGitProvider.
    It only carries a title and the list of diff files.
    """

    def __init__(self, title: str, diff_files: List[FilePatchInfo]):
        self.title, self.diff_files = title, diff_files
class LocalGitProvider(GitProvider):
    """
    This class implements the GitProvider interface for local git repositories.
    It mimics the PR functionality of the GitProvider interface,
    but does not require a hosted git repository.
    Instead of providing a PR url, the user provides a local branch path to generate a diff-patch.
    For the MVP it only supports the /review and /describe capabilities.
    """

    def __init__(self, target_branch_name, incremental=False):
        """
        Open the enclosing repository and prepare a PR-mimic of HEAD against `target_branch_name`.

        Raises:
            ValueError: If no repository root is found or the working tree is dirty.
            KeyError: If `target_branch_name` is not a local branch.
        """
        self.repo_path = _find_repository_root()
        if self.repo_path is None:
            raise ValueError('Could not find repository root')
        self.repo = Repo(self.repo_path)
        self.head_branch_name = self.repo.head.ref.name
        self.target_branch_name = target_branch_name
        self._prepare_repo()
        self.diff_files = None  # filled by the first get_diff_files() call
        self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())
        self.description_path = get_settings().get('local.description_path') \
            if get_settings().get('local.description_path') is not None else self.repo_path / 'description.md'
        self.review_path = get_settings().get('local.review_path') \
            if get_settings().get('local.review_path') is not None else self.repo_path / 'review.md'
        # inline code comments are not supported for local git repositories
        get_settings().pr_reviewer.inline_code_comments = False

    def _prepare_repo(self):
        """
        Prepare the repository for PR-mimic generation.
        """
        get_logger().debug('Preparing repository for PR-mimic generation...')
        if self.repo.is_dirty():
            raise ValueError('The repository is not in a clean state. Please commit or stash pending changes.')
        if self.target_branch_name not in self.repo.heads:
            raise KeyError(f'Branch: {self.target_branch_name} does not exist')

    def is_supported(self, capability: str) -> bool:
        """Return False for the capabilities this provider declares unsupported, True otherwise."""
        if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments', 'get_labels',
                          'gfm_markdown']:
            return False
        return True

    @staticmethod
    def _read_blob_text(blob) -> str:
        """Return the blob's content as text; a missing blob is treated as an empty file."""
        if blob is None:
            return ""  # empty file
        # undecodable bytes (e.g. binary files) are replaced instead of aborting the whole diff
        return blob.data_stream.read().decode('utf-8', errors='replace')

    def get_diff_files(self) -> list[FilePatchInfo]:
        """
        Return one FilePatchInfo per file that differs between HEAD and its merge base with
        the target branch. The list is computed once and cached on the instance.
        """
        if self.diff_files is not None:
            return self.diff_files

        diffs = self.repo.head.commit.diff(
            self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]),
            create_patch=True,
            R=True
        )
        diff_files = []
        for diff_item in diffs:
            original_file_content_str = self._read_blob_text(diff_item.a_blob)
            new_file_content_str = self._read_blob_text(diff_item.b_blob)

            edit_type = EDIT_TYPE.MODIFIED
            if diff_item.new_file:
                edit_type = EDIT_TYPE.ADDED
            elif diff_item.deleted_file:
                edit_type = EDIT_TYPE.DELETED
            elif diff_item.renamed_file:
                edit_type = EDIT_TYPE.RENAMED

            diff_files.append(
                FilePatchInfo(original_file_content_str,
                              new_file_content_str,
                              diff_item.diff.decode('utf-8', errors='replace'),
                              diff_item.b_path,
                              edit_type=edit_type,
                              old_filename=None if diff_item.a_path == diff_item.b_path else diff_item.a_path
                              )
            )
        self.diff_files = diff_files
        return diff_files

    def get_files(self) -> List[str]:
        """
        Returns a list of files with changes in the diff.
        """
        diff_index = self.repo.head.commit.diff(
            self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]),
            R=True
        )
        # Get the list of changed files
        diff_files = [item.a_path for item in diff_index]
        return diff_files

    def publish_description(self, pr_title: str, pr_body: str):
        """Write the title and body to `description_path`, replacing any previous content."""
        # explicit UTF-8: generated text may contain characters outside the platform default encoding
        with open(self.description_path, "w", encoding="utf-8") as file:
            file.write(pr_title + '\n' + pr_body)

    def publish_comment(self, pr_comment: str, is_temporary: bool = False):
        """Write the comment to `review_path`, replacing any previous content."""
        with open(self.review_path, "w", encoding="utf-8") as file:
            file.write(pr_comment)

    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError('Publishing inline comments is not implemented for the local git provider')

    def publish_inline_comments(self, comments: list[dict]):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError('Publishing inline comments is not implemented for the local git provider')

    def publish_code_suggestion(self, body: str, relevant_file: str,
                                relevant_lines_start: int, relevant_lines_end: int):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')

    def publish_code_suggestions(self, code_suggestions: list) -> bool:
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')

    def publish_labels(self, labels):
        pass  # Not applicable to the local git provider, but required by the interface

    def remove_initial_comment(self):
        pass  # Not applicable to the local git provider, but required by the interface

    def remove_comment(self, comment):
        pass  # Not applicable to the local git provider, but required by the interface

    def add_eyes_reaction(self, comment, disable_eyes: bool = False):
        # `disable_eyes` is accepted so calls shaped like the base-class signature do not fail
        pass  # Not applicable to the local git provider, but required by the interface

    def get_commit_messages(self):
        pass  # Not applicable to the local git provider, but required by the interface

    def get_repo_settings(self):
        pass  # Not applicable to the local git provider, but required by the interface

    def remove_reaction(self, comment, reaction_id=None):
        # `reaction_id` is accepted so calls shaped like the base-class signature do not fail
        pass  # Not applicable to the local git provider, but required by the interface

    def get_languages(self):
        """
        Calculate percentage of languages in repository. Used for hunk prioritisation.

        Languages are identified by lower-cased file extension (without the dot); files
        without an extension are counted under the empty string.
        """
        # Get all files in repository
        filepaths = [Path(item.path) for item in self.repo.tree().traverse() if item.type == 'blob']
        # Identify language by file extension and count
        lang_count = Counter(filepath.suffix.lower().lstrip('.') for filepath in filepaths)
        # Convert counts to percentages
        total_files = len(filepaths)
        lang_percentage = {lang: count / total_files * 100 for lang, count in lang_count.items()}
        return lang_percentage

    def get_pr_branch(self):
        """Return the repository's HEAD reference object."""
        return self.repo.head

    def get_user_id(self):
        return -1  # Not used anywhere for the local provider, but required by the interface

    def get_pr_description_full(self):
        """Return the first 200 characters of the space-joined messages of commits in target..HEAD."""
        commits_diff = list(self.repo.iter_commits(self.target_branch_name + '..HEAD'))
        # Get the commit messages and concatenate
        commit_messages = " ".join([commit.message for commit in commits_diff])
        # TODO Handle the description better - maybe use gpt-3.5 summarisation here?
        return commit_messages[:200]  # Use max 200 characters

    def get_pr_title(self):
        """
        Substitutes the branch-name as the PR-mimic title.
        """
        return self.head_branch_name

    def get_issue_comments(self):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError('Getting issue comments is not implemented for the local git provider')

    def get_pr_labels(self, update=False):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError('Getting labels is not implemented for the local git provider')

View File

@ -0,0 +1,102 @@
import copy
import os
import tempfile
from dynaconf import Dynaconf
from starlette_context import context
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import (get_git_provider_with_context)
from utils.pr_agent.log import get_logger
def apply_repo_settings(pr_url):
    """Merge the repository's own settings file into the active settings.

    When ``config.use_repo_settings_file`` is enabled, the raw settings are
    taken from the request context (key ``"repo_settings"``) or, failing that,
    fetched from the git provider and cached back into the context. They are
    written to a temporary ``.toml`` file, parsed with Dynaconf, and every
    section is merged key-by-key over the current settings. Parse/apply
    failures are logged and reported through ``handle_configurations_errors``.

    Independently of the above, the short model name ``claude-3-5-sonnet`` is
    expanded via ``set_claude_model``.

    Args:
        pr_url: URL of the pull request whose repository settings are applied.
    """
    git_provider = get_git_provider_with_context(pr_url)
    if get_settings().config.use_repo_settings_file:
        repo_settings_file = None
        try:
            # The context is only available inside a request; outside of one
            # the lookup raises and we fall back to asking the provider.
            try:
                repo_settings = context.get("repo_settings", None)
            except Exception:
                repo_settings = None
            if repo_settings is None:  # None is different from "", which is a valid value
                repo_settings = git_provider.get_repo_settings()
                try:
                    context["repo_settings"] = repo_settings
                except Exception:
                    pass

            error_local = None
            if repo_settings:
                repo_settings_file = None
                category = 'local'
                try:
                    fd, repo_settings_file = tempfile.mkstemp(suffix='.toml')
                    # mkstemp hands back an open OS-level descriptor; close it
                    # as soon as the content is written so it is not leaked.
                    try:
                        os.write(fd, repo_settings)
                    finally:
                        os.close(fd)
                    new_settings = Dynaconf(settings_files=[repo_settings_file])
                    for section, contents in new_settings.as_dict().items():
                        # Start from a copy of the current section so keys the
                        # repo file does not mention keep their values.
                        section_dict = copy.deepcopy(get_settings().as_dict().get(section, {}))
                        for key, value in contents.items():
                            section_dict[key] = value
                        get_settings().unset(section)
                        get_settings().set(section, section_dict, merge=False)
                    get_logger().info(f"Applying repo settings:\n{new_settings.as_dict()}")
                except Exception as e:
                    get_logger().warning(f"Failed to apply repo {category} settings, error: {str(e)}")
                    error_local = {'error': str(e), 'settings': repo_settings, 'category': category}

            if error_local:
                handle_configurations_errors([error_local], git_provider)
        except Exception as e:
            get_logger().exception("Failed to apply repo settings", e)
        finally:
            if repo_settings_file:
                try:
                    os.remove(repo_settings_file)
                except Exception as e:
                    get_logger().error(f"Failed to remove temporary settings file {repo_settings_file}", e)

    # enable switching models with a short definition
    if get_settings().config.model.lower() == 'claude-3-5-sonnet':
        set_claude_model()
def handle_configurations_errors(config_errors, git_provider):
    """Publish a PR comment for every repo-settings error in ``config_errors``.

    Args:
        config_errors: iterable of error dicts with the keys ``'settings'``
            (raw bytes of the settings file), ``'error'`` and ``'category'``;
            falsy entries are skipped.
        git_provider: provider used to publish the comment. A persistent
            comment is used when the provider offers
            ``publish_persistent_comment``, a plain comment otherwise.

    Any exception raised while building or publishing is logged, not raised.
    """
    try:
        if not any(config_errors):
            return
        for entry in config_errors:
            if not entry:
                continue
            settings_text = entry['settings'].decode()
            error_text = entry['error']
            category = entry['category']
            header = f"❌ **PR-Agent failed to apply '{category}' repo settings**"
            sections = [
                f"{header}\n\nThe configuration file needs to be a valid [TOML](https://qodo-merge-docs.qodo.ai/usage-guide/configuration_options/), please fix it.\n\n",
                f"___\n\n**Error message:**\n`{error_text}`\n\n",
            ]
            # Collapsible block only where GitHub-flavoured markdown is rendered
            if git_provider.is_supported("gfm_markdown"):
                sections.append(f"\n\n<details><summary>配置内容:</summary>\n\n```toml\n{settings_text}\n```\n\n</details>")
            else:
                sections.append(f"\n\n**配置内容:**\n\n```toml\n{settings_text}\n```\n\n")
            body = "".join(sections)
            get_logger().warning("Sending a 'configuration error' comment to the PR", artifact={'body': body})
            if not hasattr(git_provider, 'publish_persistent_comment'):
                git_provider.publish_comment(body)
            else:
                git_provider.publish_persistent_comment(body,
                                                        initial_header=header,
                                                        update_header=False,
                                                        final_update_message=False)
    except Exception as e:
        get_logger().exception("Failed to handle configurations errors", e)
def set_claude_model():
    """
    Point the main, weak and fallback model settings at the Bedrock
    claude-3-5-sonnet model, so users can select it just by stating:
    --config.model='claude-3-5-sonnet'
    """
    bedrock_model_id = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
    overrides = (
        ('config.model', bedrock_model_id),
        ('config.model_weak', bedrock_model_id),
        ('config.fallback_models', [bedrock_model_id]),
    )
    for setting_name, setting_value in overrides:
        get_settings().set(setting_name, setting_value)

View File

@ -0,0 +1,14 @@
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.identity_providers.default_identity_provider import \
DefaultIdentityProvider
# Registry of identity provider classes, keyed by the id used in the
# CONFIG.IDENTITY_PROVIDER setting.
_IDENTITY_PROVIDERS = {
    'default': DefaultIdentityProvider
}
def get_identity_provider():
    """Instantiate the identity provider selected by CONFIG.IDENTITY_PROVIDER.

    Falls back to the ``"default"`` provider when the setting is absent.

    Raises:
        ValueError: if the configured id is not registered.
    """
    provider_key = get_settings().get("CONFIG.IDENTITY_PROVIDER", "default")
    if provider_key in _IDENTITY_PROVIDERS:
        provider_cls = _IDENTITY_PROVIDERS[provider_key]
        return provider_cls()
    raise ValueError(f"Unknown identity provider: {provider_key}")

View File

@ -0,0 +1,10 @@
from utils.pr_agent.identity_providers.identity_provider import (Eligibility,
IdentityProvider)
class DefaultIdentityProvider(IdentityProvider):
    """Identity provider that treats every caller as eligible and keeps no counters."""

    def verify_eligibility(self, git_provider, git_provider_id, pr_url):
        """Return ``Eligibility.ELIGIBLE`` regardless of the arguments."""
        verdict = Eligibility.ELIGIBLE
        return verdict

    def inc_invocation_count(self, git_provider, git_provider_id):
        """Do nothing; the default provider does not track invocations."""
        return None

View File

@ -0,0 +1,18 @@
from abc import ABC, abstractmethod
from enum import Enum
class Eligibility(Enum):
    """Possible outcomes of an identity provider's eligibility check."""

    NOT_ELIGIBLE = 0
    ELIGIBLE = 1
    TRIAL = 2
class IdentityProvider(ABC):
    """Abstract interface for eligibility checks and invocation accounting."""

    @abstractmethod
    def verify_eligibility(self, git_provider, git_provier_id, pr_url):
        """Decide whether the given identity may use the agent on ``pr_url``.

        NOTE: the parameter name ``git_provier_id`` (sic) is kept as-is so
        keyword callers are not broken.
        """

    @abstractmethod
    def inc_invocation_count(self, git_provider, git_provider_id):
        """Record one more invocation for the given identity."""

View File

@ -0,0 +1,64 @@
import logging
import os
import sys
from enum import Enum
from loguru import logger
from utils.pr_agent.config_loader import get_settings
class LoggingFormat(str, Enum):
    """Output formats understood by ``setup_logger``."""

    CONSOLE = "CONSOLE"
    JSON = "JSON"
def json_format(record: dict) -> str:
    """Return the ``"message"`` entry of a log record."""
    message = record["message"]
    return message
def analytics_filter(record: dict) -> bool:
    """Return the ``analytics`` flag from the record's ``extra`` dict (``False`` if absent)."""
    extra_fields = record.get("extra", {})
    return extra_fields.get("analytics", False)
def inv_analytics_filter(record: dict) -> bool:
    """Return ``True`` for records that are NOT flagged as analytics in ``extra``."""
    extra_fields = record.get("extra", {})
    is_analytics = extra_fields.get("analytics", False)
    return not is_analytics
def setup_logger(level: str = "INFO", fmt: LoggingFormat = LoggingFormat.CONSOLE):
    """Configure the shared loguru logger and return it.

    Args:
        level: logging level name; names that do not resolve to an integer
            level fall back to ``logging.INFO``.
        fmt: ``LoggingFormat.JSON`` installs a serialized stdout sink (unless
            the ``LOG_SANE`` environment variable is set to something other
            than ``"0"``); ``LoggingFormat.CONSOLE`` installs a colorized
            stdout sink. In both cases existing sinks are removed first and
            analytics records are excluded.

    When ``CONFIG.ANALYTICS_FOLDER`` is set, analytics records are additionally
    written, serialized, to ``pr-agent.<pid>.log`` inside that folder.
    """
    resolved_level = logging.getLevelName(level.upper())
    level = resolved_level if type(resolved_level) is int else logging.INFO

    # Options shared by both serialized (JSON) sinks
    serialized_sink = dict(level=level, format="{message}", colorize=False, serialize=True)

    use_json_stdout = fmt == LoggingFormat.JSON and os.getenv("LOG_SANE", "0").lower() == "0"  # better debugging github_app
    if use_json_stdout:
        logger.remove(None)
        logger.add(sys.stdout, filter=inv_analytics_filter, **serialized_sink)
    elif fmt == LoggingFormat.CONSOLE:  # does not print the 'extra' fields
        logger.remove(None)
        logger.add(sys.stdout, level=level, colorize=True, filter=inv_analytics_filter)

    analytics_dir = get_settings().get("CONFIG.ANALYTICS_FOLDER", "")
    if analytics_dir:
        analytics_path = os.path.join(analytics_dir, f"pr-agent.{os.getpid()}.log")
        logger.add(analytics_path, filter=analytics_filter, **serialized_sink)

    return logger
def get_logger(*args, **kwargs):
    """Return the shared loguru ``logger``; all arguments are accepted and ignored."""
    shared_logger = logger
    return shared_logger

View File

@ -0,0 +1,17 @@
from utils.pr_agent.config_loader import get_settings
def get_secret_provider():
    """Build the secret provider named by CONFIG.SECRET_PROVIDER.

    Returns:
        ``None`` when no provider is configured, otherwise a provider instance.

    Raises:
        ValueError: if the provider id is unknown, or if the
            ``google_cloud_storage`` provider cannot be imported/initialized.
    """
    if not get_settings().get("CONFIG.SECRET_PROVIDER"):
        return None

    provider_id = get_settings().config.secret_provider
    if provider_id != 'google_cloud_storage':
        raise ValueError("Unknown SECRET_PROVIDER")

    try:
        # Imported lazily so the Google client library is only needed when used
        from utils.pr_agent.secret_providers.google_cloud_storage_secret_provider import \
            GoogleCloudStorageSecretProvider
        return GoogleCloudStorageSecretProvider()
    except Exception as e:
        raise ValueError(f"Failed to initialize google_cloud_storage secret provider {provider_id}") from e

View File

@ -0,0 +1,34 @@
import ujson
from google.cloud import storage
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.log import get_logger
from utils.pr_agent.secret_providers.secret_provider import SecretProvider
class GoogleCloudStorageSecretProvider(SecretProvider):
    """Secret provider that keeps each secret as a blob in a Google Cloud Storage bucket."""

    def __init__(self):
        """Create the storage client and bucket handle from the ``google_cloud_storage`` settings.

        Initialization errors are logged and re-raised.
        """
        try:
            service_account_info = ujson.loads(get_settings().google_cloud_storage.service_account)
            self.client = storage.Client.from_service_account_info(service_account_info)
            self.bucket_name = get_settings().google_cloud_storage.bucket_name
            self.bucket = self.client.bucket(self.bucket_name)
        except Exception as e:
            get_logger().error(f"Failed to initialize Google Cloud Storage Secret Provider: {e}")
            raise

    def get_secret(self, secret_name: str) -> str:
        """Download the blob named ``secret_name``; return ``""`` (after a warning) on any failure.

        NOTE(review): the value comes straight from ``download_as_string()``,
        which presumably yields bytes rather than ``str`` - confirm against callers.
        """
        try:
            secret_blob = self.bucket.blob(secret_name)
            return secret_blob.download_as_string()
        except Exception as e:
            get_logger().warning(f"Failed to get secret {secret_name} from Google Cloud Storage: {e}")
            return ""

    def store_secret(self, secret_name: str, secret_value: str):
        """Upload ``secret_value`` as the blob named ``secret_name``; failures are logged and re-raised."""
        try:
            secret_blob = self.bucket.blob(secret_name)
            secret_blob.upload_from_string(secret_value)
        except Exception as e:
            get_logger().error(f"Failed to store secret {secret_name} in Google Cloud Storage: {e}")
            raise

View File

@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
class SecretProvider(ABC):
    """Abstract interface for reading and writing named secrets."""

    @abstractmethod
    def get_secret(self, secret_name: str) -> str:
        """Return the secret stored under ``secret_name``."""

    @abstractmethod
    def store_secret(self, secret_name: str, secret_value: str):
        """Store ``secret_value`` under ``secret_name``."""

View File

View File

@ -0,0 +1,34 @@
{
"name": "CodiumAI PR-Agent",
"description": "CodiumAI PR-Agent",
"key": "app_key",
"vendor": {
"name": "CodiumAI",
"url": "https://codium.ai"
},
"authentication": {
"type": "jwt"
},
"baseUrl": "base_url",
"lifecycle": {
"installed": "/installed",
"uninstalled": "/uninstalled"
},
"scopes": [
"account",
"repository:write",
"pullrequest:write",
"wiki"
],
"contexts": [
"account"
],
"modules": {
"webhooks": [
{
"event": "*",
"url": "/webhook"
}
]
}
}

View File

@ -0,0 +1,148 @@
# This file contains the code for the Azure DevOps Server webhook server.
# The server listens for incoming webhooks from Azure DevOps Server and forwards them to the PR Agent.
# ADO webhook documentation: https://learn.microsoft.com/en-us/azure/devops/service-hooks/services/webhooks?view=azure-devops
import json
import os
import re
import secrets
from urllib.parse import unquote
import uvicorn
from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from starlette import status
from starlette.background import BackgroundTasks
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette_context.middleware import RawContextMiddleware
from utils.pr_agent.agent.pr_agent import PRAgent, command2class
from utils.pr_agent.algo.utils import update_settings_from_args
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers.utils import apply_repo_settings
from utils.pr_agent.log import LoggingFormat, get_logger, setup_logger
# Configure JSON logging at DEBUG level when this module is imported.
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
# HTTP Basic scheme; consumed through Depends(security) by the auth dependency.
security = HTTPBasic()
router = APIRouter()
# Matches text that starts with "/<command>" for any command key in command2class.
available_commands_rgx = re.compile(r"^\/(" + "|".join(command2class.keys()) + r")\s*")
# Expected webhook credentials, read once from the azure_devops_server settings section.
# NOTE(review): if that section is absent the .get(...) calls below would fail at
# import time - confirm the section is always configured for this server.
azure_devops_server = get_settings().get("azure_devops_server")
WEBHOOK_USERNAME = azure_devops_server.get("webhook_username")
WEBHOOK_PASSWORD = azure_devops_server.get("webhook_password")
def handle_request(
    background_tasks: BackgroundTasks, url: str, body: str, log_context: dict
):
    """Schedule ``PRAgent().handle_request(url, body)`` as a background task.

    ``log_context`` is updated in place with the ``action`` and ``api_url``
    keys and is used to contextualize the logger while the task runs.
    Exceptions raised by the agent are logged, not propagated.
    """
    log_context.update(action=body, api_url=url)

    async def run_agent():
        try:
            with get_logger().contextualize(**log_context):
                await PRAgent().handle_request(url, body)
        except Exception as e:
            get_logger().error(f"Failed to handle webhook: {e}")

    background_tasks.add_task(run_agent)
# currently only basic auth is supported with azure webhooks
# for this reason, https must be enabled to ensure the credentials are not sent in clear text
def authorize(credentials: HTTPBasicCredentials = Depends(security)):
    """FastAPI dependency that validates the webhook's HTTP Basic credentials.

    The supplied username/password are compared against ``WEBHOOK_USERNAME``
    and ``WEBHOOK_PASSWORD`` with ``secrets.compare_digest``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when the
            credentials do not match, or when no expected value is configured.
    """
    def _matches(supplied: str, expected) -> bool:
        # A missing/non-string configured value can never match; reject it
        # instead of letting compare_digest raise a TypeError (HTTP 500).
        if not isinstance(expected, str):
            return False
        # compare_digest only accepts str arguments when they are pure ASCII;
        # comparing UTF-8 bytes keeps non-ASCII credentials from crashing it.
        return secrets.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))

    # Evaluate both comparisons unconditionally, as before, so the work done
    # does not depend on which field is wrong.
    is_user_ok = _matches(credentials.username, WEBHOOK_USERNAME)
    is_pass_ok = _matches(credentials.password, WEBHOOK_PASSWORD)
    if not (is_user_ok and is_pass_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail='Incorrect username or password.',
            headers={'WWW-Authenticate': 'Basic'},
        )
async def _perform_commands_azure(commands_conf: str, agent: PRAgent, api_url: str, log_context: dict):
    """Run the commands listed under ``azure_devops_server.<commands_conf>``.

    Repo settings are applied first. For ``"pr_commands"`` nothing is run when
    ``config.disable_auto_feedback`` is set. Each command string is split on
    spaces; its arguments are passed through ``update_settings_from_args`` and
    the remaining ones are re-joined before the agent handles the command.
    A failing command is logged and does not stop the others.
    """
    apply_repo_settings(api_url)
    # auto commands for PR, and auto feedback is disabled
    if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback:
        get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}", **log_context)
        return

    configured_commands = get_settings().get(f"azure_devops_server.{commands_conf}")
    get_settings().set("config.is_auto_command", True)
    for command in configured_commands:
        try:
            # Rebind `command` to the bare command name; the error log below reports it
            command, *args = command.split(" ")
            remaining_args = update_settings_from_args(args)
            new_command = ' '.join([command] + remaining_args)
            get_logger().info(f"Performing command: {new_command}")
            with get_logger().contextualize(**log_context):
                await agent.handle_request(api_url, new_command)
        except Exception as e:
            get_logger().error(f"Failed to perform command {command}: {e}")
@router.post("/", dependencies=[Depends(authorize)])
async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
    """Entry point for Azure DevOps Server webhooks.

    * ``git.pullrequest.created``: runs the configured ``pr_commands``
      immediately and returns without a body.
    * ``ms.vss-code.git-pullrequest-comment-event`` whose comment has content:
      if the comment starts with a known command and the payload is resource
      version 2.0, the command is scheduled as a background task (202);
      version 1.0 payloads and unknown commands get a 400.
    * Anything else: 204.
    """
    log_context = {"server_type": "azure_devops_server"}
    data = await request.json()
    get_logger().info(json.dumps(data))

    actions = []
    if data["eventType"] == "git.pullrequest.created":
        # API V1 (latest)
        pr_url = unquote(data["resource"]["_links"]["web"]["href"].replace("_apis/git/repositories", "_git"))
        log_context["event"] = data["eventType"]
        log_context["api_url"] = pr_url
        await _perform_commands_azure("pr_commands", PRAgent(), pr_url, log_context)
        return
    elif data["eventType"] == "ms.vss-code.git-pullrequest-comment-event" and "content" in data["resource"]["comment"]:
        if available_commands_rgx.match(data["resource"]["comment"]["content"]):
            if data["resourceVersion"] == "2.0":
                repo = data["resource"]["pullRequest"]["repository"]["webUrl"]
                pr_url = unquote(f'{repo}/pullrequest/{data["resource"]["pullRequest"]["pullRequestId"]}')
                actions = [data["resource"]["comment"]["content"]]
            else:
                # API V1 not supported as it does not contain the PR URL.
                # Return the response object itself: a trailing comma here used
                # to turn it into a one-element tuple.
                return JSONResponse(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    content=json.dumps({"message": "version 1.0 webhook for Azure Devops PR comment is not supported. please upgrade to version 2.0"}),
                )
        else:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content=json.dumps({"message": "Unsupported command"}),
            )
    else:
        return JSONResponse(
            status_code=status.HTTP_204_NO_CONTENT,
            content=json.dumps({"message": "Unsupported event"}),
        )

    log_context["event"] = data["eventType"]
    log_context["api_url"] = pr_url

    for action in actions:
        try:
            handle_request(background_tasks, pr_url, action, log_context)
        except Exception as e:
            get_logger().error("Azure DevOps Trigger failed. Error:" + str(e))
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content=json.dumps({"message": "Internal server error"}),
            )
    return JSONResponse(
        status_code=status.HTTP_202_ACCEPTED, content=jsonable_encoder({"message": "webhook triggered successfully"})
    )
@router.get("/")
async def root():
    """Liveness endpoint; always answers with a fixed status payload."""
    health = {"status": "ok"}
    return health
def start():
    """Build the FastAPI app with the raw-context middleware and serve it with uvicorn.

    Listens on all interfaces; the port is taken from the ``PORT`` environment
    variable and defaults to 3000.
    """
    middleware = [Middleware(RawContextMiddleware)]
    app = FastAPI(middleware=middleware)
    app.include_router(router)
    port = int(os.getenv("PORT", "3000"))
    uvicorn.run(app, host="0.0.0.0", port=port)


if __name__ == "__main__":
    start()

View File

@ -0,0 +1,272 @@
import base64
import copy
import hashlib
import json
import os
import re
import time
import jwt
import requests
import uvicorn
from fastapi import APIRouter, FastAPI, Request, Response
from starlette.background import BackgroundTasks
from starlette.middleware import Middleware
from starlette.responses import JSONResponse
from starlette_context import context
from starlette_context.middleware import RawContextMiddleware
from utils.pr_agent.agent.pr_agent import PRAgent
from utils.pr_agent.algo.utils import update_settings_from_args
from utils.pr_agent.config_loader import get_settings, global_settings
from utils.pr_agent.git_providers.utils import apply_repo_settings
from utils.pr_agent.identity_providers import get_identity_provider
from utils.pr_agent.identity_providers.identity_provider import Eligibility
from utils.pr_agent.log import LoggingFormat, get_logger, setup_logger
from utils.pr_agent.secret_providers import get_secret_provider
# Configure JSON logging at DEBUG level when this module is imported.
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
router = APIRouter()
# Secret store for per-installation shared secrets; None when CONFIG.SECRET_PROVIDER is unset.
# NOTE(review): the webhook/install handlers in this module call secret_provider
# without a None check - confirm a provider is always configured for this server.
secret_provider = get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
async def get_bearer_token(shared_secret: str, client_key: str):
    """Exchange a self-signed JWT for a Bitbucket OAuth access token.

    A short-lived (240 s) HS256 token is signed with ``shared_secret`` for the
    installation ``client_key`` and posted to Bitbucket's token endpoint using
    the JWT grant type.

    Returns:
        The ``access_token`` field of the endpoint's JSON response.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        issued_at = int(time.time())
        token_url = "https://bitbucket.org/site/oauth2/access_token"
        # Query-string hash over the canonical form of the token request
        canonical_request = "GET&/site/oauth2/access_token&"
        qsh = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()
        claims = {
            "iss": get_settings().bitbucket.app_key,
            "iat": issued_at,
            "exp": issued_at + 240,
            "qsh": qsh,
            "sub": client_key,
        }
        signed_jwt = jwt.encode(claims, shared_secret, algorithm="HS256")
        form_body = 'grant_type=urn%3Abitbucket%3Aoauth2%3Ajwt'
        request_headers = {
            'Authorization': f'JWT {signed_jwt}',
            'Content-Type': 'application/x-www-form-urlencoded'
        }
        response = requests.request("POST", token_url, headers=request_headers, data=form_body)
        return response.json()["access_token"]
    except Exception as e:
        get_logger().error(f"Failed to get bearer token: {e}")
        raise e
@router.get("/")
async def handle_manifest(request: Request, response: Response):
    """Serve the Atlassian Connect manifest with the app key and base URL filled in.

    The template ``atlassian-connect.json`` is read from this module's
    directory and the placeholders ``app_key`` and ``base_url`` are replaced
    with the corresponding ``bitbucket`` settings. If the replacement fails the
    error is logged and the manifest is served as far as it got.
    """
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    # Use a context manager so the file handle is closed on every request
    with open(os.path.join(cur_dir, "atlassian-connect.json"), "rt") as manifest_file:
        manifest = manifest_file.read()
    try:
        manifest = manifest.replace("app_key", get_settings().bitbucket.app_key)
        manifest = manifest.replace("base_url", get_settings().bitbucket.base_url)
    except Exception:
        # Best effort: keep serving, but do not swallow KeyboardInterrupt/SystemExit
        get_logger().error("Failed to replace api_key in Bitbucket manifest, trying to continue")
    manifest_obj = json.loads(manifest)
    return JSONResponse(manifest_obj)
def _get_username(data):
actor = data.get("data", {}).get("actor", {})
if actor:
if "username" in actor:
return actor["username"]
elif "display_name" in actor:
return actor["display_name"]
elif "nickname" in actor:
return actor["nickname"]
return ""
async def _perform_commands_bitbucket(commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict):
    """Run the commands listed under ``bitbucket_app.<commands_conf>``.

    Repo settings are applied first. Nothing is run when ``commands_conf`` is
    ``"pr_commands"`` and ``config.disable_auto_feedback`` is set, or when the
    payload is a ``pullrequest:created`` event that ``should_process_pr_logic``
    rejects. Each command string is split on spaces; its arguments go through
    ``update_settings_from_args`` and the remaining ones are re-joined before
    the agent handles the command. A failing command is logged and skipped.
    """
    apply_repo_settings(api_url)
    # auto commands for PR, and auto feedback is disabled
    if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback:
        get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}")
        return
    if data.get("event", "") == "pullrequest:created" and not should_process_pr_logic(data):
        return

    configured_commands = get_settings().get(f"bitbucket_app.{commands_conf}", {})
    get_settings().set("config.is_auto_command", True)
    for command in configured_commands:
        try:
            # Rebind `command` to the bare command name; the error log below reports it
            command, *args = command.split(" ")
            remaining_args = update_settings_from_args(args)
            new_command = ' '.join([command] + remaining_args)
            get_logger().info(f"Performing command: {new_command}")
            with get_logger().contextualize(**log_context):
                await agent.handle_request(api_url, new_command)
        except Exception as e:
            get_logger().error(f"Failed to perform command {command}: {e}")
def is_bot_user(data) -> bool:
    """Return ``True`` when the webhook actor's type is anything other than ``user``.

    A payload without an actor, or one whose inspection fails (the failure is
    logged), is treated as not being a bot.
    """
    try:
        actor = data.get("data", {}).get("actor", {})
        if not actor:
            return False
        # allow actor type: user . if it's "AppUser" or "team" then it is a bot user
        if actor["type"].lower() in {"user"}:
            return False
        get_logger().info(f"BitBucket actor type is not 'user', skipping: {actor}")
        return True
    except Exception as e:
        get_logger().error(f"Failed 'is_bot_user' logic: {e}")
    return False
def should_process_pr_logic(data) -> bool:
    """Decide whether the PR in a Bitbucket webhook payload should be processed.

    Returns ``False`` when the sender is listed in ``CONFIG.IGNORE_PR_AUTHORS``,
    the title matches a regex in ``CONFIG.IGNORE_PR_TITLE``, or the source /
    target branch matches a regex in ``CONFIG.IGNORE_PR_SOURCE_BRANCHES`` /
    ``CONFIG.IGNORE_PR_TARGET_BRANCHES``. Any error while evaluating the rules
    is logged and the PR is processed.
    """
    try:
        pull_request = data.get("data", {}).get("pullrequest", {})
        title = pull_request.get("title", "")
        source_branch = pull_request.get("source", {}).get("branch", {}).get("name", "")
        target_branch = pull_request.get("destination", {}).get("branch", {}).get("name", "")
        sender = _get_username(data)

        # logic to ignore PRs from specific users
        ignored_authors = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
        if ignored_authors and sender and sender in ignored_authors:
            get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting")
            return False

        # logic to ignore PRs with specific titles
        if title:
            title_patterns = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
            if not isinstance(title_patterns, list):
                title_patterns = [title_patterns]
            if title_patterns and any(re.search(pattern, title) for pattern in title_patterns):
                get_logger().info(f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting")
                return False

        ignored_sources = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [])
        ignored_targets = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", [])
        if ignored_sources or ignored_targets:
            source_hit = any(re.search(pattern, source_branch) for pattern in ignored_sources)
            if source_hit:
                get_logger().info(
                    f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings")
                return False
            target_hit = any(re.search(pattern, target_branch) for pattern in ignored_targets)
            if target_hit:
                get_logger().info(
                    f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings")
                return False
    except Exception as e:
        get_logger().error(f"Failed 'should_process_pr_logic': {e}")
    return True
@router.post("/webhook")
async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Request):
    """Entry point for Bitbucket app webhooks.

    The request is acknowledged with ``"OK"`` right away; all work happens in a
    background task which:

    1. skips bot actors and PRs rejected by ``should_process_pr_logic``;
    2. reads the installation's client key from the JWT claims, loads its
       shared secret from the secret provider and verifies the JWT with it;
    3. obtains a bearer token, stores it (and a fresh copy of the global
       settings) in the request context;
    4. for ``pullrequest:created`` runs the configured ``pr_commands``, and for
       ``pullrequest:comment_created`` passes the comment text to the agent -
       both only when the identity provider does not report NOT_ELIGIBLE.

    Errors inside the background task are logged, not propagated.
    """
    app_name = get_settings().get("CONFIG.APP_NAME", "Unknown")
    log_context = {"server_type": "bitbucket_app", "app_name": app_name}
    get_logger().debug(request.headers)
    # Always bind input_jwt: previously it stayed undefined when the
    # Authorization header was absent (NameError in the background task) and a
    # header without a space raised IndexError before the body was read.
    input_jwt = None
    jwt_header = request.headers.get("authorization", None)
    if jwt_header:
        header_parts = jwt_header.split(" ")
        if len(header_parts) > 1:
            input_jwt = header_parts[1]
    data = await request.json()
    get_logger().debug(data)

    async def inner():
        try:
            # ignore bot users
            if is_bot_user(data):
                return "OK"

            # Check if the PR should be processed
            if data.get("event", "") == "pullrequest:created":
                if not should_process_pr_logic(data):
                    return "OK"

            # Get the username of the sender
            log_context["sender"] = _get_username(data)

            sender_id = data.get("data", {}).get("actor", {}).get("account_id", "")
            log_context["sender_id"] = sender_id

            # Without a token the request cannot be authenticated; stop here
            # with an explicit message instead of failing further down.
            if not input_jwt:
                get_logger().error("Failed to handle webhook: missing or malformed Authorization header")
                return "OK"

            # The claims are read unverified only to learn which installation
            # (client key) signed the token; the signature is checked below.
            jwt_parts = input_jwt.split(".")
            claim_part = jwt_parts[1]
            claim_part += "=" * (-len(claim_part) % 4)  # restore stripped base64 padding
            decoded_claims = base64.urlsafe_b64decode(claim_part)
            claims = json.loads(decoded_claims)
            client_key = claims["iss"]
            secrets = json.loads(secret_provider.get_secret(client_key))
            shared_secret = secrets["shared_secret"]
            jwt.decode(input_jwt, shared_secret, audience=client_key, algorithms=["HS256"])
            bearer_token = await get_bearer_token(shared_secret, client_key)
            context['bitbucket_bearer_token'] = bearer_token
            context["settings"] = copy.deepcopy(global_settings)
            event = data["event"]
            agent = PRAgent()
            if event == "pullrequest:created":
                pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
                log_context["api_url"] = pr_url
                log_context["event"] = "pull_request"
                if pr_url:
                    with get_logger().contextualize(**log_context):
                        apply_repo_settings(pr_url)
                        if get_identity_provider().verify_eligibility("bitbucket",
                                                                      sender_id, pr_url) is not Eligibility.NOT_ELIGIBLE:
                            if get_settings().get("bitbucket_app.pr_commands"):
                                await _perform_commands_bitbucket("pr_commands", PRAgent(), pr_url, log_context, data)
            elif event == "pullrequest:comment_created":
                pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
                log_context["api_url"] = pr_url
                log_context["event"] = "comment"
                comment_body = data["data"]["comment"]["content"]["raw"]
                with get_logger().contextualize(**log_context):
                    if get_identity_provider().verify_eligibility("bitbucket",
                                                                  sender_id, pr_url) is not Eligibility.NOT_ELIGIBLE:
                        await agent.handle_request(pr_url, comment_body)
        except Exception as e:
            get_logger().error(f"Failed to handle webhook: {e}")

    background_tasks.add_task(inner)
    return "OK"
@router.get("/webhook")
async def handle_github_webhooks(request: Request, response: Response):
    """Answer GET probes on the webhook path with a fixed liveness string."""
    liveness_message = "Webhook server online!"
    return liveness_message
@router.post("/installed")
async def handle_installed_webhooks(request: Request, response: Response):
    """Handle the app's ``installed`` lifecycle callback.

    Stores the payload's ``sharedSecret`` and ``clientKey`` as a JSON document
    in the secret provider, keyed by ``principal.username``. On any failure the
    error is logged and a 500 JSON response is returned.
    """
    try:
        get_logger().info("handle_installed_webhooks")
        get_logger().info(request.headers)
        data = await request.json()
        # NOTE(review): the payload logged here includes the shared secret -
        # confirm that is acceptable for the deployment's log handling.
        get_logger().info(data)
        shared_secret = data["sharedSecret"]
        client_key = data["clientKey"]
        username = data["principal"]["username"]
        installation_record = json.dumps({
            "shared_secret": shared_secret,
            "client_key": client_key
        })
        secret_provider.store_secret(username, installation_record)
    except Exception as e:
        get_logger().error(f"Failed to register user: {e}")
        return JSONResponse({"error": "Unable to register user"}, status_code=500)
@router.post("/uninstalled")
async def handle_uninstalled_webhooks(request: Request, response: Response):
    """Handle the app's ``uninstalled`` lifecycle callback by logging its payload."""
    get_logger().info("handle_uninstalled_webhooks")
    payload = await request.json()
    get_logger().info(payload)
def start():
    """Apply the Bitbucket-specific setting overrides and serve the app with uvicorn.

    Listens on all interfaces; the port is taken from the ``PORT`` environment
    variable and defaults to 3000.
    """
    startup_overrides = (
        ("CONFIG.PUBLISH_OUTPUT_PROGRESS", False),
        ("CONFIG.GIT_PROVIDER", "bitbucket"),
        ("PR_DESCRIPTION.PUBLISH_DESCRIPTION_AS_COMMENT", True),
    )
    for setting_name, setting_value in startup_overrides:
        get_settings().set(setting_name, setting_value)
    app = FastAPI(middleware=[Middleware(RawContextMiddleware)])
    app.include_router(router)
    port = int(os.environ.get("PORT", "3000"))
    uvicorn.run(app, host="0.0.0.0", port=port)


if __name__ == '__main__':
    start()

View File

@ -0,0 +1,164 @@
import ast
import json
import os
from typing import List
import uvicorn
from fastapi import APIRouter, FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.responses import RedirectResponse
from starlette import status
from starlette.background import BackgroundTasks
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette_context.middleware import RawContextMiddleware
from utils.pr_agent.agent.pr_agent import PRAgent
from utils.pr_agent.algo.utils import update_settings_from_args
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers.utils import apply_repo_settings
from utils.pr_agent.log import LoggingFormat, get_logger, setup_logger
from utils.pr_agent.servers.utils import verify_signature
# Configure JSON logging at DEBUG level when this module is imported.
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
# Router that collects this module's webhook endpoints.
router = APIRouter()
def handle_request(
    background_tasks: BackgroundTasks, url: str, body: str, log_context: dict
):
    """Schedule ``PRAgent().handle_request(url, body)`` as a background task.

    ``log_context`` is updated in place with the ``action`` and ``api_url``
    keys and is used to contextualize the logger while the task runs.
    Exceptions raised by the agent are logged, not propagated.
    """
    log_context.update(action=body, api_url=url)

    async def run_agent():
        try:
            with get_logger().contextualize(**log_context):
                await PRAgent().handle_request(url, body)
        except Exception as e:
            get_logger().error(f"Failed to handle webhook: {e}")

    background_tasks.add_task(run_agent)
@router.post("/")
async def redirect_to_webhook():
    """Redirect POSTs on the root path to ``/webhook``."""
    target_path = "/webhook"
    return RedirectResponse(url=target_path)
@router.post("/webhook")
async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
    """Entry point for Bitbucket Server webhooks.

    When ``BITBUCKET_SERVER.WEBHOOK_SECRET`` is configured, the connection-test
    payload is acknowledged directly and every other payload is checked with
    ``verify_signature``. ``pr:opened`` runs the configured PR commands (unless
    auto feedback is disabled), ``pr:comment:added`` runs the comment text as a
    command, and any other event is answered with a 400. Commands are executed
    sequentially in a background task.
    """
    log_context = {"server_type": "bitbucket_server"}
    data = await request.json()
    get_logger().info(json.dumps(data))

    webhook_secret = get_settings().get("BITBUCKET_SERVER.WEBHOOK_SECRET", None)
    if webhook_secret:
        raw_body = await request.body()
        is_connection_test = raw_body.decode('utf-8') == '{"test": true}'
        if is_connection_test:
            return JSONResponse(
                status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "connection test successful"})
            )
        signature_header = request.headers.get("x-hub-signature", None)
        verify_signature(raw_body, webhook_secret, signature_header)

    # Rebuild the PR's web URL from the payload's ids and the configured server URL
    pr_id = data["pullRequest"]["id"]
    target_repository = data["pullRequest"]["toRef"]["repository"]
    repository_name = target_repository["slug"]
    project_name = target_repository["project"]["key"]
    bitbucket_server = get_settings().get("BITBUCKET_SERVER.URL")
    pr_url = f"{bitbucket_server}/projects/{project_name}/repos/{repository_name}/pull-requests/{pr_id}"

    log_context["api_url"] = pr_url
    log_context["event"] = "pull_request"

    commands_to_run = []
    event_key = data["eventKey"]
    if event_key == "pr:opened":
        apply_repo_settings(pr_url)
        # auto commands for PR, and auto feedback is disabled
        if get_settings().config.disable_auto_feedback:
            get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {pr_url}", **log_context)
            return
        get_settings().set("config.is_auto_command", True)
        commands_to_run.extend(_get_commands_list_from_settings('BITBUCKET_SERVER.PR_COMMANDS'))
    elif event_key == "pr:comment:added":
        commands_to_run.append(data["comment"]["text"])
    else:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=json.dumps({"message": "Unsupported event"}),
        )

    async def run_commands():
        try:
            await _run_commands_sequentially(commands_to_run, pr_url, log_context)
        except Exception as e:
            get_logger().error(f"Failed to handle webhook: {e}")

    background_tasks.add_task(run_commands)

    return JSONResponse(
        status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"})
    )
async def _run_commands_sequentially(commands: List[str], url: str, log_context: dict):
    """Run each command against the PR at ``url``, one after another.

    Every command is preprocessed with ``_process_command`` and handed to a new
    ``PRAgent``. ``log_context`` is updated in place with the ``action`` and
    ``api_url`` of the command being run. A failing command is logged and the
    remaining ones still run. ``commands=None`` is a no-op.
    """
    get_logger().info(f"Running commands sequentially: {commands}")
    if commands is None:
        return

    for command in commands:
        try:
            processed = _process_command(command, url)
            log_context.update(action=processed, api_url=url)
            with get_logger().contextualize(**log_context):
                await PRAgent().handle_request(url, processed)
        except Exception as e:
            get_logger().error(f"Failed to handle command: {command} , error: {e}")
def _process_command(command: str, url) -> str:
    """Apply repo settings for ``url`` and normalise one command string.

    The string is split on spaces; the arguments are passed through
    ``update_settings_from_args`` and whatever that call returns is re-joined
    behind the command name.
    """
    # don't think we need this
    apply_repo_settings(url)
    # Process the command string
    name, *args = command.split(" ")
    # do I need this? if yes, shouldn't this be done in PRAgent?
    remaining_args = update_settings_from_args(args)
    return ' '.join([name] + remaining_args)
def _to_list(command_string: str) -> list:
try:
# Use ast.literal_eval to safely parse the string into a list
commands = ast.literal_eval(command_string)
# Check if the parsed object is a list of strings
if isinstance(commands, list) and all(isinstance(cmd, str) for cmd in commands):
return commands
else:
raise ValueError("Parsed data is not a list of strings.")
except (SyntaxError, ValueError, TypeError) as e:
raise ValueError(f"Invalid command string: {e}")
def _get_commands_list_from_settings(setting_key:str ) -> list:
    """Return the command list stored under ``setting_key`` (``[]`` when unset).

    If reading the setting raises ``ValueError`` the error is logged and an
    empty list is returned, so callers can always extend/iterate the result.
    """
    try:
        return get_settings().get(setting_key, [])
    except ValueError as e:
        get_logger().error(f"Failed to get commands list from settings {setting_key}: {e}")
        # Previously fell through and returned None, which made the caller's
        # `commands_to_run.extend(...)` raise a TypeError.
        return []
@router.get("/")
async def root():
    """Liveness endpoint; always answers with a fixed status payload."""
    health = {"status": "ok"}
    return health
def start():
    """Build the FastAPI app with the raw-context middleware and serve it with uvicorn.

    Listens on all interfaces; the port is taken from the ``PORT`` environment
    variable and defaults to 3000.
    """
    middleware = [Middleware(RawContextMiddleware)]
    app = FastAPI(middleware=middleware)
    app.include_router(router)
    port = int(os.getenv("PORT", "3000"))
    uvicorn.run(app, host="0.0.0.0", port=port)


if __name__ == "__main__":
    start()

View File

@ -0,0 +1,77 @@
import copy
from enum import Enum
from json import JSONDecodeError
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException
from pydantic import BaseModel
from starlette.middleware import Middleware
from starlette_context import context
from starlette_context.middleware import RawContextMiddleware
from utils.pr_agent.agent.pr_agent import PRAgent
from utils.pr_agent.config_loader import get_settings, global_settings
from utils.pr_agent.log import get_logger, setup_logger
# Configure logging with setup_logger's defaults when this module is imported.
setup_logger()
# Router that collects this module's Gerrit endpoints.
router = APIRouter()
class Action(str, Enum):
    """Values accepted for the ``{action}`` segment of the Gerrit endpoint path."""

    review = "review"
    describe = "describe"
    ask = "ask"
    improve = "improve"
    reflect = "reflect"
    answer = "answer"
class Item(BaseModel):
    # Request body of the Gerrit endpoint; none of the fields has a default.
    # `project` and `refspec` are combined as "<project>:<refspec>" to identify the change.
    refspec: str
    project: str
    # Stripped and prefixed with "/" to form the command text sent to the agent.
    msg: str
@router.post("/api/v1/gerrit/{action}")
async def handle_gerrit_request(action: Action, item: Item):
    """Run a PR-Agent command against a Gerrit change.

    The change is identified as ``"<project>:<refspec>"`` and the command text
    is ``item.msg`` (stripped) with a leading ``/``. A fresh copy of the global
    settings is placed in the request context before the agent runs.

    Raises:
        HTTPException: 400 when ``action`` is ``ask`` and ``item.msg`` is empty.
    """
    get_logger().debug("Received a Gerrit request")
    context["settings"] = copy.deepcopy(global_settings)

    if action == Action.ask:
        if not item.msg:
            # Must be raised, not returned: a returned HTTPException is treated
            # as an ordinary response value, so the request is not rejected
            # with a 400.
            raise HTTPException(
                status_code=400,
                detail="msg is required for ask command"
            )
    await PRAgent().handle_request(
        f"{item.project}:{item.refspec}",
        f"/{item.msg.strip()}"
    )
async def get_body(request):
    """Return the request's JSON body, or ``{}`` (after logging) when it is not valid JSON."""
    try:
        return await request.json()
    except JSONDecodeError as e:
        get_logger().error("Error parsing request body", e)
        return {}
@router.get("/")
async def root():
    """Liveness endpoint; always answers with a fixed status payload."""
    health = {"status": "ok"}
    return health
def start():
    """Enable CLI mode and serve the Gerrit app with uvicorn on port 3000 (all interfaces)."""
    # to prevent adding help messages with the output
    get_settings().set("CONFIG.CLI_MODE", True)
    app = FastAPI(middleware=[Middleware(RawContextMiddleware)])
    app.include_router(router)
    listen_port = 3000
    uvicorn.run(app, host="0.0.0.0", port=listen_port)


if __name__ == '__main__':
    start()

View File

@ -0,0 +1,160 @@
import asyncio
import json
import os
from typing import Union
from utils.pr_agent.agent.pr_agent import PRAgent
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import get_git_provider
from utils.pr_agent.git_providers.utils import apply_repo_settings
from utils.pr_agent.log import get_logger
from utils.pr_agent.servers.github_app import handle_line_comments
from utils.pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from utils.pr_agent.tools.pr_description import PRDescription
from utils.pr_agent.tools.pr_reviewer import PRReviewer
def is_true(value: Union[str, bool]) -> bool:
    """Interpret *value* as a boolean flag.

    Booleans are returned unchanged, strings are true only when they spell
    "true" (case-insensitively), and every other value is false.
    """
    if isinstance(value, str):
        return value.lower() == 'true'
    return value if isinstance(value, bool) else False
def get_setting_or_env(key: str, default: Union[str, bool] = None) -> Union[str, bool]:
    """Read *key* from the settings, falling back to environment variables.

    The environment is consulted only when the settings lookup raises
    ``AttributeError``; the key is then tried as given, upper-cased and
    lower-cased, and *default* is returned when none yields a non-empty value.
    """
    try:
        return get_settings().get(key, default)
    except AttributeError:  # TBD still need to debug why this happens on GitHub Actions
        for env_name in (key, key.upper(), key.lower()):
            env_value = os.getenv(env_name, None)
            if env_value:
                return env_value
        return default
async def run_action():
    """Entry point of the GitHub Action: read the triggering event and run the matching tools.

    Input comes from environment variables (``GITHUB_EVENT_NAME``,
    ``GITHUB_EVENT_PATH``, ``GITHUB_TOKEN`` and the optional OpenAI ones).
    The function prints a message and returns early when a required variable
    is missing or the event payload is not valid JSON.
    """
    # Get environment variables
    GITHUB_EVENT_NAME = os.environ.get('GITHUB_EVENT_NAME')
    GITHUB_EVENT_PATH = os.environ.get('GITHUB_EVENT_PATH')
    OPENAI_KEY = os.environ.get('OPENAI_KEY') or os.environ.get('OPENAI.KEY')
    OPENAI_ORG = os.environ.get('OPENAI_ORG') or os.environ.get('OPENAI.ORG')
    GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN')
    # get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)

    # Check if required environment variables are set
    if not GITHUB_EVENT_NAME:
        print("GITHUB_EVENT_NAME not set")
        return
    if not GITHUB_EVENT_PATH:
        print("GITHUB_EVENT_PATH not set")
        return
    if not GITHUB_TOKEN:
        print("GITHUB_TOKEN not set")
        return

    # Set the environment variables in the settings
    if OPENAI_KEY:
        get_settings().set("OPENAI.KEY", OPENAI_KEY)
    else:
        # Might not be set if the user is using models not from OpenAI
        print("OPENAI_KEY not set")
    if OPENAI_ORG:
        get_settings().set("OPENAI.ORG", OPENAI_ORG)
    get_settings().set("GITHUB.USER_TOKEN", GITHUB_TOKEN)
    get_settings().set("GITHUB.DEPLOYMENT_TYPE", "user")
    # Resolve the flag through the settings/env fallback and write it back to the settings.
    enable_output = get_setting_or_env("GITHUB_ACTION_CONFIG.ENABLE_OUTPUT", True)
    get_settings().set("GITHUB_ACTION_CONFIG.ENABLE_OUTPUT", enable_output)

    # Load the event payload
    # NOTE(review): only JSON decoding errors are handled here; a missing or
    # unreadable file propagates to the caller.
    try:
        with open(GITHUB_EVENT_PATH, 'r') as f:
            event_payload = json.load(f)
    except json.decoder.JSONDecodeError as e:
        print(f"Failed to parse JSON: {e}")
        return

    # Best effort: a failure to apply the repo settings is logged and the run continues.
    try:
        get_logger().info("Applying repo settings")
        pr_url = event_payload.get("pull_request", {}).get("html_url")
        if pr_url:
            apply_repo_settings(pr_url)
            get_logger().info(f"enable_custom_labels: {get_settings().config.enable_custom_labels}")
    except Exception as e:
        get_logger().info(f"github action: failed to apply repo settings: {e}")

    # Handle pull request opened event
    if GITHUB_EVENT_NAME == "pull_request" or GITHUB_EVENT_NAME == "pull_request_target":
        action = event_payload.get("action")

        # Retrieve the list of actions from the configuration
        pr_actions = get_settings().get("GITHUB_ACTION_CONFIG.PR_ACTIONS", ["opened", "reopened", "ready_for_review", "review_requested"])

        if action in pr_actions:
            # The tools below receive the API ``url`` field, not the ``html_url`` used above.
            pr_url = event_payload.get("pull_request", {}).get("url")
            if pr_url:
                # legacy - supporting both GITHUB_ACTION and GITHUB_ACTION_CONFIG
                auto_review = get_setting_or_env("GITHUB_ACTION.AUTO_REVIEW", None)
                if auto_review is None:
                    auto_review = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_REVIEW", None)
                auto_describe = get_setting_or_env("GITHUB_ACTION.AUTO_DESCRIBE", None)
                if auto_describe is None:
                    auto_describe = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_DESCRIBE", None)
                auto_improve = get_setting_or_env("GITHUB_ACTION.AUTO_IMPROVE", None)
                if auto_improve is None:
                    auto_improve = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_IMPROVE", None)

                # Set the configuration for auto actions
                get_settings().config.is_auto_command = True  # Set the flag to indicate that the command is auto
                get_settings().pr_description.final_update_message = False  # No final update message when auto_describe is enabled
                get_logger().info(f"Running auto actions: auto_describe={auto_describe}, auto_review={auto_review}, auto_improve={auto_improve}")

                # invoke by default all three tools
                # (an unset flag counts as enabled; only an explicit non-true value disables a tool)
                if auto_describe is None or is_true(auto_describe):
                    await PRDescription(pr_url).run()
                if auto_review is None or is_true(auto_review):
                    await PRReviewer(pr_url).run()
                if auto_improve is None or is_true(auto_improve):
                    await PRCodeSuggestions(pr_url).run()
        else:
            get_logger().info(f"Skipping action: {action}")

    # Handle issue comment event
    elif GITHUB_EVENT_NAME == "issue_comment" or GITHUB_EVENT_NAME == "pull_request_review_comment":
        action = event_payload.get("action")
        if action in ["created", "edited"]:
            comment_body = event_payload.get("comment", {}).get("body")
            # A review comment containing "/ask" is rewritten by handle_line_comments.
            # comment_body may be None here; the resulting TypeError is caught below.
            try:
                if GITHUB_EVENT_NAME == "pull_request_review_comment":
                    if '/ask' in comment_body:
                        comment_body = handle_line_comments(event_payload, comment_body)
            except Exception as e:
                get_logger().error(f"Failed to handle line comments: {e}")
                return
            if comment_body:
                is_pr = False
                disable_eyes = False
                # check if issue is pull request
                if event_payload.get("issue", {}).get("pull_request"):
                    url = event_payload.get("issue", {}).get("pull_request", {}).get("url")
                    is_pr = True
                elif event_payload.get("comment", {}).get("pull_request_url"):  # for 'pull_request_review_comment
                    url = event_payload.get("comment", {}).get("pull_request_url")
                    is_pr = True
                    disable_eyes = True
                else:
                    url = event_payload.get("issue", {}).get("url")

                if url:
                    # NOTE(review): lower-casing applies to the whole comment,
                    # including any free-text question after the command.
                    body = comment_body.strip().lower()
                    comment_id = event_payload.get("comment", {}).get("id")
                    provider = get_git_provider()(pr_url=url)
                    if is_pr:
                        # The notify callback adds the "eyes" reaction to the triggering comment.
                        await PRAgent().handle_request(
                            url, body, notify=lambda: provider.add_eyes_reaction(
                                comment_id, disable_eyes=disable_eyes
                            )
                        )
                    else:
                        await PRAgent().handle_request(url, body)


if __name__ == '__main__':
    asyncio.run(run_action())

View File

@ -0,0 +1,424 @@
import asyncio.locks
import copy
import os
import re
import uuid
from typing import Any, Dict, Tuple
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
from starlette.background import BackgroundTasks
from starlette.middleware import Middleware
from starlette_context import context
from starlette_context.middleware import RawContextMiddleware
from utils.pr_agent.agent.pr_agent import PRAgent
from utils.pr_agent.algo.utils import update_settings_from_args
from utils.pr_agent.config_loader import get_settings, global_settings
from utils.pr_agent.git_providers import (get_git_provider,
get_git_provider_with_context)
from utils.pr_agent.git_providers.utils import apply_repo_settings
from utils.pr_agent.identity_providers import get_identity_provider
from utils.pr_agent.identity_providers.identity_provider import Eligibility
from utils.pr_agent.log import LoggingFormat, get_logger, setup_logger
from utils.pr_agent.servers.utils import DefaultDictWithTimeout, verify_signature
# Configure logging at import time: JSON format, DEBUG level.
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
# base_path is two directory levels above this file. An optional
# build_number.txt there supplies the build identifier that handle_request
# passes into the log context; "unknown" is used when the file is absent.
base_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
build_number_path = os.path.join(base_path, "build_number.txt")
if os.path.exists(build_number_path):
    with open(build_number_path) as f:
        build_number = f.read().strip()
else:
    build_number = "unknown"
# Router that collects the endpoints declared below.
router = APIRouter()
@router.post("/api/v1/github_webhooks")
async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Request, response: Response):
    """
    Accept a GitHub webhook and schedule its processing.

    The body is read through get_body, per-request state is placed in the
    request context, and handle_request is queued as a background task so
    the webhook is answered with an empty JSON object.
    """
    get_logger().debug("Received a GitHub webhook")

    payload = await get_body(request)
    github_event = request.headers.get("X-GitHub-Event", None)

    context["installation_id"] = payload.get("installation", {}).get("id")
    context["settings"] = copy.deepcopy(global_settings)
    context["git_provider"] = {}

    background_tasks.add_task(handle_request, payload, event=github_event)
    return {}
@router.post("/api/v1/marketplace_webhooks")
async def handle_marketplace_webhooks(request: Request, response: Response):
    """Log the body of an incoming marketplace webhook; nothing else is done with it."""
    payload = await get_body(request)
    get_logger().info(f'Request body:\n{payload}')
async def get_body(request):
    """Parse the JSON body of *request* and check its signature when a secret is configured.

    Returns:
        The decoded JSON body.

    Raises:
        HTTPException: 400 when the body cannot be parsed as JSON.
    """
    try:
        body = await request.json()
    except Exception as e:
        # Interpolate the exception into the message: handed over as a stray
        # positional argument it is not rendered, because the message has no
        # placeholder for it.
        get_logger().error(f"Error parsing request body: {e}")
        raise HTTPException(status_code=400, detail="Error parsing request body") from e
    webhook_secret = getattr(get_settings().github, 'webhook_secret', None)
    if webhook_secret:
        # The signature is checked against the raw bytes, not the decoded JSON.
        # NOTE(review): verify_signature presumably raises on a mismatch; confirm in servers.utils.
        body_bytes = await request.body()
        signature_header = request.headers.get('x-hub-signature-256', None)
        verify_signature(body_bytes, webhook_secret, signature_header)
    return body
# Bookkeeping for push triggers, keyed by the PR API URL in
# handle_push_trigger_for_new_commits: a count of active tasks and a
# Condition the waiting task blocks on.
# NOTE(review): entries presumably expire after ``ttl``; confirm in DefaultDictWithTimeout.
_duplicate_push_triggers = DefaultDictWithTimeout(ttl=get_settings().github_app.push_trigger_pending_tasks_ttl)
_pending_task_duplicate_push_conditions = DefaultDictWithTimeout(asyncio.locks.Condition, ttl=get_settings().github_app.push_trigger_pending_tasks_ttl)
async def handle_comments_on_pr(body: Dict[str, Any],
                                event: str,
                                sender: str,
                                sender_id: str,
                                action: str,
                                log_context: Dict[str, Any],
                                agent: PRAgent):
    """Run the command carried by a comment on a pull request.

    Comments that do not start with "/" are ignored, except for a quoted
    image followed by "/ask", which is rearranged so the command comes first.
    Returns ``{}`` when the payload is ignored and ``None`` after a comment
    has been handled or the sender was found not eligible.
    """
    if "comment" not in body:
        return {}

    comment_body = body.get("comment", {}).get("body")
    if comment_body and isinstance(comment_body, str) and not comment_body.lstrip().startswith("/"):
        is_quoted_image_ask = '/ask' in comment_body and comment_body.strip().startswith('> ![image]')
        if not is_quoted_image_ask:
            get_logger().info("Ignoring comment not starting with /")
            return {}
        # Put the command first and append the quoted part (without the leading ">").
        pieces = comment_body.split('/ask')
        comment_body = '/ask' + pieces[1] + ' \n' + pieces[0].strip().lstrip('>')
        get_logger().info(f"Reformatting comment_body so command is at the beginning: {comment_body}")

    disable_eyes = False
    if "issue" in body and "pull_request" in body["issue"] and "url" in body["issue"]["pull_request"]:
        api_url = body["issue"]["pull_request"]["url"]
    else:
        if not ("comment" in body and "pull_request_url" in body["comment"]):
            return {}
        api_url = body["comment"]["pull_request_url"]
        try:
            on_code_line = ('/ask' in comment_body and
                            'subject_type' in body["comment"] and body["comment"]["subject_type"] == "line")
            if on_code_line:
                # comment on a code line in the "files changed" tab
                comment_body = handle_line_comments(body, comment_body)
                disable_eyes = True
        except Exception as e:
            get_logger().error(f"Failed to handle line comments: {e}")

    log_context["api_url"] = api_url
    comment_id = body.get("comment", {}).get("id")
    provider = get_git_provider_with_context(pr_url=api_url)

    def _add_eyes():
        return provider.add_eyes_reaction(comment_id, disable_eyes=disable_eyes)

    with get_logger().contextualize(**log_context):
        if get_identity_provider().verify_eligibility("github", sender_id, api_url) is Eligibility.NOT_ELIGIBLE:
            get_logger().info(f"User {sender=} is not eligible to process comment on PR {api_url=}")
        else:
            get_logger().info(f"Processing comment on PR {api_url=}, comment_body={comment_body}")
            await agent.handle_request(api_url, comment_body, notify=_add_eyes)
async def handle_new_pr_opened(body: Dict[str, Any],
                               event: str,
                               sender: str,
                               sender_id: str,
                               action: str,
                               log_context: Dict[str, Any],
                               agent: PRAgent):
    """Run the configured automatic commands for a pull-request event.

    Args:
        body: Webhook payload.
        event: GitHub event name (not used here).
        sender: Login of the user that triggered the event; used in log messages.
        sender_id: Id of that user; used for the eligibility check.
        action: Webhook action, compared against ``github_app.handle_pr_actions``.
        log_context: Logging context; the PR API URL is added to it.
        agent: Agent that executes the commands.

    Returns:
        ``{}`` when the event does not describe a processable pull request,
        otherwise ``None``.
    """
    # The payload is validated before anything is read from it, so a payload
    # whose "pull_request" entry is null is reported as invalid rather than
    # failing on an attribute lookup.
    pull_request, api_url = _check_pull_request_event(action, body, log_context)
    if not (pull_request and api_url):
        get_logger().info(f"Invalid PR event: {action=} {api_url=}")
        return {}
    if action in get_settings().github_app.handle_pr_actions:  # ['opened', 'reopened', 'ready_for_review']
        # Load the repository's own settings before deciding what to run.
        apply_repo_settings(api_url)
        if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
            await _perform_auto_commands_github("pr_commands", agent, body, api_url, log_context)
        else:
            get_logger().info(f"User {sender=} is not eligible to process PR {api_url=}")
async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
                                              event: str,
                                              sender: str,
                                              sender_id: str,
                                              action: str,
                                              log_context: Dict[str, Any],
                                              agent: PRAgent):
    """Run the configured ``push_commands`` when new commits are pushed to a PR.

    Returns ``{}`` whenever the event is skipped (invalid PR, push trigger
    disabled, no new commit, ignored merge commit, or a duplicate trigger)
    and ``None`` after the commands were attempted.
    """
    pull_request, api_url = _check_pull_request_event(action, body, log_context)
    if not (pull_request and api_url):
        return {}

    apply_repo_settings(api_url) # we need to apply the repo settings to get the correct settings for the PR. This is quite expensive - a call to the git provider is made for each PR event.
    if not get_settings().github_app.handle_push_trigger:
        return {}

    # TODO: do we still want to get the list of commits to filter bot/merge commits?
    before_sha = body.get("before")
    after_sha = body.get("after")
    merge_commit_sha = pull_request.get("merge_commit_sha")
    # Nothing was pushed if the head did not move.
    if before_sha == after_sha:
        return {}
    if get_settings().github_app.push_trigger_ignore_merge_commits and after_sha == merge_commit_sha:
        return {}

    # Prevent triggering multiple times for subsequent push triggers when one is enough:
    # The first push will trigger the processing, and if there's a second push in the meanwhile it will wait.
    # Any more events will be discarded, because they will all trigger the exact same processing on the PR.
    # We let the second event wait instead of discarding it because while the first event was being processed,
    # more commits may have been pushed that led to the subsequent events,
    # so we keep just one waiting as a delegate to trigger the processing for the new commits when done waiting.
    current_active_tasks = _duplicate_push_triggers.setdefault(api_url, 0)
    max_active_tasks = 2 if get_settings().github_app.push_trigger_pending_tasks_backlog else 1
    if current_active_tasks < max_active_tasks:
        # first task can enter, and second tasks too if backlog is enabled
        get_logger().info(
            f"Continue processing push trigger for {api_url=} because there are {current_active_tasks} active tasks"
        )
        _duplicate_push_triggers[api_url] += 1
    else:
        get_logger().info(
            f"Skipping push trigger for {api_url=} because another event already triggered the same processing"
        )
        return {}
    async with _pending_task_duplicate_push_conditions[api_url]:
        # current_active_tasks is the count read before this task registered
        # itself, so 1 means one other task is already running.
        if current_active_tasks == 1:
            # second task waits
            get_logger().info(
                f"Waiting to process push trigger for {api_url=} because the first task is still in progress"
            )
            await _pending_task_duplicate_push_conditions[api_url].wait()
            get_logger().info(f"Finished waiting to process push trigger for {api_url=} - continue with flow")

    try:
        if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
            get_logger().info(f"Performing incremental review for {api_url=} because of {event=} and {action=}")
            await _perform_auto_commands_github("push_commands", agent, body, api_url, log_context)

    finally:
        # release the waiting task block
        # (runs even when the commands fail, so the counter is always decremented)
        async with _pending_task_duplicate_push_conditions[api_url]:
            _pending_task_duplicate_push_conditions[api_url].notify(1)
            _duplicate_push_triggers[api_url] -= 1
def handle_closed_pr(body, event, action, log_context):
    """Log statistics for a closed pull request; unmerged pull requests are skipped."""
    pull_request = body.get("pull_request", {})
    if not pull_request.get("merged", False):
        return

    api_url = pull_request.get("url", "")
    provider = get_git_provider()(pr_url=api_url)
    pr_statistics = provider.calc_pr_statistics(pull_request)
    log_context["api_url"] = api_url
    get_logger().info("PR-Agent statistics for closed PR", analytics=True, pr_statistics=pr_statistics, **log_context)
def get_log_context(body, event, action, build_number):
    """Collect the logging context and sender details of a webhook payload.

    Args:
        body: Webhook payload.
        event: GitHub event name.
        action: Webhook action.
        build_number: Build identifier recorded in the context.

    Returns:
        A tuple ``(log_context, sender, sender_id, sender_type)``. When an
        error occurs ``log_context`` is empty and the sender fields keep
        whatever had been read before the failure.
    """
    sender = ""
    sender_id = ""
    sender_type = ""
    try:
        sender = body.get("sender", {}).get("login")
        sender_id = body.get("sender", {}).get("id")
        sender_type = body.get("sender", {}).get("type")
        repo = body.get("repository", {}).get("full_name", "")
        git_org = body.get("organization", {}).get("login", "")
        installation_id = body.get("installation", {}).get("id", "")
        app_name = get_settings().get("CONFIG.APP_NAME", "Unknown")
        log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app",
                       "request_id": uuid.uuid4().hex, "build_number": build_number, "app_name": app_name,
                       "repo": repo, "git_org": git_org, "installation_id": installation_id}
    except Exception as e:
        # Interpolate the exception into the message: handed over as a stray
        # positional argument it is not rendered, because the message has no
        # placeholder for it.
        get_logger().error(f"Failed to get log context: {e}")
        log_context = {}
    return log_context, sender, sender_id, sender_type
def is_bot_user(sender, sender_type):
    """Return True when the event comes from a bot that should be ignored.

    Only applies when ``GITHUB_APP.IGNORE_BOT_PR`` is enabled; senders whose
    name contains "pr-agent" are never ignored. Any error is logged and
    treated as "not a bot".
    """
    try:
        # logic to ignore PRs opened by bot
        ignore_bots = get_settings().get("GITHUB_APP.IGNORE_BOT_PR", False)
        if ignore_bots and sender_type == "Bot" and 'pr-agent' not in sender:
            get_logger().info(f"Ignoring PR from '{sender=}' because it is a bot")
            return True
    except Exception as e:
        get_logger().error(f"Failed 'is_bot_user' logic: {e}")
    return False
def should_process_pr_logic(body) -> bool:
    """Decide whether the pull request in *body* should be processed.

    The PR is rejected when its sender, title, labels, source branch or
    target branch match the corresponding ``CONFIG.IGNORE_PR_*`` setting.
    Any error during the checks is logged and the PR is processed.
    """
    try:
        pull_request = body.get("pull_request", {})
        title = pull_request.get("title", "")
        pr_labels = pull_request.get("labels", [])
        source_branch = pull_request.get("head", {}).get("ref", "")
        target_branch = pull_request.get("base", {}).get("ref", "")
        sender = body.get("sender", {}).get("login")

        # logic to ignore PRs from specific users
        ignored_authors = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
        if ignored_authors and sender and sender in ignored_authors:
            get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting")
            return False

        # logic to ignore PRs with specific titles
        if title:
            title_patterns = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
            if not isinstance(title_patterns, list):
                title_patterns = [title_patterns]
            for pattern in title_patterns:
                if re.search(pattern, title):
                    get_logger().info(f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting")
                    return False

        # logic to ignore PRs with specific labels or source branches or target branches.
        ignored_labels = get_settings().get("CONFIG.IGNORE_PR_LABELS", [])
        if pr_labels and ignored_labels:
            label_names = [entry['name'] for entry in pr_labels]
            if any(name in ignored_labels for name in label_names):
                labels_str = ", ".join(label_names)
                get_logger().info(f"Ignoring PR with labels '{labels_str}' due to config.ignore_pr_labels settings")
                return False

        # logic to ignore PRs with specific source or target branches
        source_patterns = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [])
        target_patterns = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", [])
        if pull_request and (source_patterns or target_patterns):
            for pattern in source_patterns:
                if re.search(pattern, source_branch):
                    get_logger().info(
                        f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings")
                    return False
            for pattern in target_patterns:
                if re.search(pattern, target_branch):
                    get_logger().info(
                        f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings")
                    return False
    except Exception as e:
        get_logger().error(f"Failed 'should_process_pr_logic': {e}")
    return True
async def handle_request(body: Dict[str, Any], event: str):
    """
    Handle incoming GitHub webhook requests.

    Payloads without an ``action`` are ignored. Otherwise the payload is
    filtered (bot senders, ignore rules) and routed to the handler matching
    the event/action pair; the branches below are tried in order.

    Args:
        body: The request body.
        event: The GitHub event type (e.g. "pull_request", "issue_comment", etc.).

    Returns:
        Always an empty dict.
    """
    action = body.get("action")  # "created", "opened", "reopened", "ready_for_review", "review_requested", "synchronize"
    if not action:
        return {}
    agent = PRAgent()
    log_context, sender, sender_id, sender_type = get_log_context(body, event, action, build_number)

    # logic to ignore PRs opened by bot, PRs with specific titles, labels, source branches, or target branches
    if is_bot_user(sender, sender_type) and 'check_run' not in body:
        return {}
    # Comment events ('created') and check runs skip the PR ignore rules.
    if action != 'created' and 'check_run' not in body:
        if not should_process_pr_logic(body):
            return {}

    if 'check_run' in body:  # handle failed checks
        # get_logger().debug(f'Request body', artifact=body, event=event) # added inside handle_checks
        pass
    # handle comments on PRs
    elif action == 'created':
        get_logger().debug(f'Request body', artifact=body, event=event)
        await handle_comments_on_pr(body, event, sender, sender_id, action, log_context, agent)
    # handle new PRs
    elif event == 'pull_request' and action != 'synchronize' and action != 'closed':
        get_logger().debug(f'Request body', artifact=body, event=event)
        await handle_new_pr_opened(body, event, sender, sender_id, action, log_context, agent)
    elif event == "issue_comment" and 'edited' in action:
        pass # handle_checkbox_clicked
    # handle pull_request event with synchronize action - "push trigger" for new commits
    elif event == 'pull_request' and action == 'synchronize':
        await handle_push_trigger_for_new_commits(body, event, sender,sender_id, action, log_context, agent)
    elif event == 'pull_request' and action == 'closed':
        # Statistics for closed PRs are only logged when an analytics folder is configured.
        if get_settings().get("CONFIG.ANALYTICS_FOLDER", ""):
            handle_closed_pr(body, event, action, log_context)
    else:
        get_logger().info(f"event {event=} action {action=} does not require any handling")
    return {}
def handle_line_comments(body: Dict, comment_body: str) -> str:
    """Turn an ``/ask`` comment made on a diff line into an ``/ask_line`` command.

    Args:
        body: Webhook payload; its ``comment`` entry supplies the line range,
            diff hunk, file path, side and comment id.
        comment_body: Text of the comment.

    Returns:
        The ``/ask_line ...`` command when the comment contains ``/ask``, the
        comment unchanged when it does not, or ``""`` for an empty comment.

    Raises:
        KeyError: when a required field of ``body["comment"]`` is missing.
    """
    if not comment_body:
        return ""
    comment = body["comment"]
    end_line = comment["line"]
    # Fall back to the end line when no start line is given (missing or empty).
    start_line = comment.get("start_line") or end_line
    question = comment_body.replace('/ask', '').strip()
    diff_hunk = comment["diff_hunk"]
    # Stored in the settings under "ask_diff_hunk".
    # NOTE(review): presumably read back by the ask_line tool; confirm.
    get_settings().set("ask_diff_hunk", diff_hunk)
    path = comment["path"]
    side = comment["side"]
    comment_id = comment["id"]
    if '/ask' in comment_body:
        comment_body = f"/ask_line --line_start={start_line} --line_end={end_line} --side={side} --file_name={path} --comment_id={comment_id} {question}"
    return comment_body
def _check_pull_request_event(action: str, body: dict, log_context: dict) -> Tuple[Dict[str, Any], str]:
invalid_result = {}, ""
pull_request = body.get("pull_request")
if not pull_request:
return invalid_result
api_url = pull_request.get("url")
if not api_url:
return invalid_result
log_context["api_url"] = api_url
if pull_request.get("draft", True) or pull_request.get("state") != "open":
return invalid_result
if action in ("review_requested", "synchronize") and pull_request.get("created_at") == pull_request.get("updated_at"):
# avoid double reviews when opening a PR for the first time
return invalid_result
return pull_request, api_url
async def _perform_auto_commands_github(commands_conf: str, agent: PRAgent, body: dict, api_url: str,
                                        log_context: dict):
    """Run the commands configured under ``github_app.<commands_conf>`` for a PR.

    Nothing is run when auto feedback is disabled for ``pr_commands``, when
    the PR is excluded by the ignore rules, or when no commands are configured.
    """
    apply_repo_settings(api_url)

    auto_feedback_disabled = commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
    if auto_feedback_disabled:  # auto commands for PR, and auto feedback is disabled
        get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}")
        return

    if not should_process_pr_logic(body):  # Here we already updated the configuration with the repo settings
        return {}

    commands = get_settings().get(f"github_app.{commands_conf}")
    if not commands:
        get_logger().info(f"New PR, but no auto commands configured")
        return

    get_settings().set("config.is_auto_command", True)
    for configured_command in commands:
        # The first word is the command; the remaining words are its arguments.
        name, *args = configured_command.split(" ")
        other_args = update_settings_from_args(args)
        new_command = ' '.join([name] + other_args)
        get_logger().info(f"{commands_conf}. Performing auto command '{new_command}', for {api_url=}")
        await agent.handle_request(api_url, new_command)
@router.get("/")
async def root():
    """Answer requests to the root path with a fixed status payload."""
    status_payload = {"status": "ok"}
    return status_payload
# Module-level application setup, executed when the module is imported.
if get_settings().github_app.override_deployment_type:
    # Override the deployment type to app
    get_settings().set("GITHUB.DEPLOYMENT_TYPE", "app")
# get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
# RawContextMiddleware backs the per-request ``context`` the webhook handler writes to.
middleware = [Middleware(RawContextMiddleware)]
app = FastAPI(middleware=middleware)
app.include_router(router)
def start():
    """Serve the application with uvicorn on ``$PORT`` (3000 when unset)."""
    port = int(os.environ.get("PORT", "3000"))
    uvicorn.run(app, host="0.0.0.0", port=port)


if __name__ == '__main__':
    start()

View File

@ -0,0 +1,241 @@
import asyncio
import multiprocessing
import traceback
from collections import deque
from datetime import datetime, timezone
import aiohttp
import requests
from utils.pr_agent.agent.pr_agent import PRAgent
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import get_git_provider
from utils.pr_agent.log import LoggingFormat, get_logger, setup_logger
# Configure logging at import time: JSON format, DEBUG level.
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
# Endpoint polled by polling_loop for the notifications of the authenticated user.
NOTIFICATION_URL = "https://api.github.com/notifications"
async def mark_notification_as_read(headers, notification, session):
    """PATCH the notification's thread to mark it as read; log any status other than 205."""
    thread_url = f"https://api.github.com/notifications/threads/{notification['id']}"
    async with session.patch(thread_url, headers=headers) as mark_read_response:
        if mark_read_response.status == 205:
            return
        get_logger().error(
            f"Failed to mark notification as read. Status code: {mark_read_response.status}")
def now() -> str:
    """
    Return the current UTC time as an ISO 8601 string.

    The "+00:00" offset produced by ``isoformat`` is written as "Z".

    Returns:
        str: The current UTC time in ISO 8601 format.
    """
    return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
async def async_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
    """Run *rest_of_comment* as a command on *pr_url* and return the agent's result.

    The notify callback handed to the agent adds an "eyes" reaction to the
    triggering comment.
    """
    def _add_eyes():
        return git_provider.add_eyes_reaction(comment_id)

    return await PRAgent().handle_request(pr_url, rest_of_comment, notify=_add_eyes)
def run_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
    """Synchronous wrapper: run async_handle_request in a fresh event loop."""
    request_coroutine = async_handle_request(pr_url, rest_of_comment, comment_id, git_provider)
    return asyncio.run(request_coroutine)
def process_comment_sync(pr_url, rest_of_comment, comment_id):
    """Process one comment synchronously; any error is logged with its traceback."""
    try:
        # Run the async handle_request in a separate function
        provider_for_pr = get_git_provider()(pr_url=pr_url)
        run_handle_request(pr_url, rest_of_comment, comment_id, provider_for_pr)
    except Exception as e:
        get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()})
async def process_comment(pr_url, rest_of_comment, comment_id):
    """Process one comment on the running event loop; any error is logged with its traceback."""
    try:
        git_provider = get_git_provider()(pr_url=pr_url)
        git_provider.set_pr(pr_url)

        def _add_eyes():
            return git_provider.add_eyes_reaction(comment_id)

        await PRAgent().handle_request(pr_url, rest_of_comment, notify=_add_eyes)
        get_logger().info(f"Finished processing comment for PR: {pr_url}")
    except Exception as e:
        get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()})
async def is_valid_notification(notification, headers, handled_ids, session, user_id):
    """Decide whether a notification is a PR mention that should be processed.

    Returns:
        ``(False, handled_ids)`` when the notification is skipped, otherwise
        the 6-tuple ``(True, handled_ids, comment, comment_body, pr_url,
        user_tag)``. Callers must check the first element before unpacking.
        Any exception is logged and reported as a skip.
    """
    try:
        # Only mentions on pull requests are considered.
        if 'reason' in notification and notification['reason'] == 'mention':
            if 'subject' in notification and notification['subject']['type'] == 'PullRequest':
                pr_url = notification['subject']['url']
                latest_comment = notification['subject']['latest_comment_url']
                if not latest_comment or not isinstance(latest_comment, str):
                    get_logger().debug(f"no latest_comment")
                    return False, handled_ids
                async with session.get(latest_comment, headers=headers) as comment_response:
                    check_prev_comments = False
                    user_tag = "@" + user_id
                    if comment_response.status == 200:
                        comment = await comment_response.json()
                        # Skip comments already seen; handled_ids is mutated in place.
                        if 'id' in comment:
                            if comment['id'] in handled_ids:
                                get_logger().debug(f"comment['id'] in handled_ids")
                                return False, handled_ids
                            else:
                                handled_ids.add(comment['id'])
                        # A latest comment written by the polling user itself, with no body,
                        # or without the tag sends us to the earlier comments.
                        if 'user' in comment and 'login' in comment['user']:
                            if comment['user']['login'] == user_id:
                                get_logger().debug(f"comment['user']['login'] == user_id")
                                check_prev_comments = True
                        comment_body = comment.get('body', '')
                        if not comment_body:
                            get_logger().debug(f"no comment_body")
                            check_prev_comments = True
                        else:
                            if user_tag not in comment_body:
                                get_logger().debug(f"user_tag not in comment_body")
                                check_prev_comments = True
                            else:
                                get_logger().info(f"Polling, pr_url: {pr_url}",
                                                  artifact={"comment": comment_body})

                        if not check_prev_comments:
                            return True, handled_ids, comment, comment_body, pr_url, user_tag
                        else:  # we could not find the user tag in the latest comment. Check previous comments
                            # get all comments in the PR
                            # NOTE(review): requests.get is a blocking call inside a coroutine.
                            requests_url = f"{pr_url}/comments".replace("pulls", "issues")
                            comments_response = requests.get(requests_url, headers=headers)
                            # Reversed so the scan starts from the end of the returned list.
                            comments = comments_response.json()[::-1]
                            max_comment_to_scan = 4
                            for comment in comments[:max_comment_to_scan]:
                                if 'user' in comment and 'login' in comment['user']:
                                    if comment['user']['login'] == user_id:
                                        continue
                                comment_body = comment.get('body', '')
                                if not comment_body:
                                    continue
                                if user_tag in comment_body:
                                    get_logger().info("found user tag in previous comments")
                                    get_logger().info(f"Polling, pr_url: {pr_url}",
                                                      artifact={"comment": comment_body})
                                    return True, handled_ids, comment, comment_body, pr_url, user_tag

                            # NOTE(review): reached when none of the scanned comments carries the
                            # tag, even if fetching succeeded, despite the wording of the message.
                            get_logger().warning(f"Failed to fetch comments for PR: {pr_url}",
                                                 artifact={"comments": comments})
                            return False, handled_ids

        return False, handled_ids
    except Exception as e:
        get_logger().exception(f"Error processing polling notification",
                               artifact={"notification": notification, "error": e})
        return False, handled_ids
async def polling_loop():
    """
    Polls for notifications and handles them accordingly.

    Runs forever: every 5 seconds the notifications endpoint is queried, each
    notification is marked as read, and every valid mention is handed to
    process_comment_sync in a separate process.

    Raises:
        ValueError: when the deployment type is not 'user' or no user token is set.
    """
    handled_ids = set()
    # One-element lists serve as mutable cells for values updated inside the loop.
    since = [now()]
    last_modified = [None]
    git_provider = get_git_provider()()
    user_id = git_provider.get_user_id()
    get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
    get_settings().set("pr_description.publish_description_as_comment", True)

    try:
        deployment_type = get_settings().github.deployment_type
        token = get_settings().github.user_token
    except AttributeError:
        deployment_type = 'none'
        token = None

    if deployment_type != 'user':
        raise ValueError("Deployment mode must be set to 'user' to get notifications")
    if not token:
        raise ValueError("User token must be set to get notifications")

    async with aiohttp.ClientSession() as session:
        while True:
            try:
                await asyncio.sleep(5)
                headers = {
                    "Accept": "application/vnd.github.v3+json",
                    "Authorization": f"Bearer {token}"
                }
                params = {
                    "participating": "true"
                }
                # "since" is only sent until a Last-Modified header has been
                # received; afterwards If-Modified-Since is used instead.
                if since[0]:
                    params["since"] = since[0]
                if last_modified[0]:
                    headers["If-Modified-Since"] = last_modified[0]

                async with session.get(NOTIFICATION_URL, headers=headers, params=params) as response:
                    if response.status == 200:
                        if 'Last-Modified' in response.headers:
                            last_modified[0] = response.headers['Last-Modified']
                            since[0] = None
                        notifications = await response.json()
                        if not notifications:
                            continue
                        get_logger().info(f"Received {len(notifications)} notifications")
                        task_queue = deque()
                        for notification in notifications:
                            if not notification:
                                continue
                            # mark notification as read
                            await mark_notification_as_read(headers, notification, session)

                            handled_ids.add(notification['id'])
                            output = await is_valid_notification(notification, headers, handled_ids, session, user_id)
                            # output has six elements only when its first element is truthy.
                            if output[0]:
                                _, handled_ids, comment, comment_body, pr_url, user_tag = output
                                # The command is whatever follows the first occurrence of the user tag.
                                rest_of_comment = comment_body.split(user_tag)[1].strip()
                                comment_id = comment['id']

                                # Add to the task queue
                                get_logger().info(
                                    f"Adding comment processing to task queue for PR, {pr_url}, comment_body: {comment_body}")
                                task_queue.append((process_comment_sync, (pr_url, rest_of_comment, comment_id)))
                                get_logger().info(f"Queued comment processing for PR: {pr_url}")
                            else:
                                get_logger().debug(f"Skipping comment processing for PR")

                        max_allowed_parallel_tasks = 10
                        if task_queue:
                            processes = []
                            for i, (func, args) in enumerate(task_queue):  # Create parallel tasks
                                p = multiprocessing.Process(target=func, args=args)
                                processes.append(p)
                                p.start()
                                # NOTE(review): the check runs after p.start() and uses ">", so
                                # up to max_allowed_parallel_tasks + 2 processes are started
                                # before the loop breaks, and the count in the message below
                                # does not match the number of tasks actually dropped.
                                if i > max_allowed_parallel_tasks:
                                    get_logger().error(
                                        f"Dropping {len(task_queue) - max_allowed_parallel_tasks} tasks from polling session")
                                    break
                            task_queue.clear()

                        # Dont wait for all processes to complete. Move on to the next iteration
                        # for p in processes:
                        #     p.join()

                    elif response.status != 304:
                        print(f"Failed to fetch notifications. Status code: {response.status}")

            except Exception as e:
                get_logger().error(f"Polling exception during processing of a notification: {e}",
                                   artifact={"traceback": traceback.format_exc()})


if __name__ == '__main__':
    asyncio.run(polling_loop())

View File

@ -0,0 +1,288 @@
import copy
import json
import re
from datetime import datetime
import uvicorn
from fastapi import APIRouter, FastAPI, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from starlette.background import BackgroundTasks
from starlette.middleware import Middleware
from starlette_context import context
from starlette_context.middleware import RawContextMiddleware
from utils.pr_agent.agent.pr_agent import PRAgent
from utils.pr_agent.algo.utils import update_settings_from_args
from utils.pr_agent.config_loader import get_settings, global_settings
from utils.pr_agent.git_providers.utils import apply_repo_settings
from utils.pr_agent.log import LoggingFormat, get_logger, setup_logger
from utils.pr_agent.secret_providers import get_secret_provider
# Configure logging at import time: JSON format, DEBUG level.
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
# Router that collects the endpoints declared below.
router = APIRouter()
# A secret provider is created only when CONFIG.SECRET_PROVIDER is set; otherwise None.
secret_provider = get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
async def get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id):
    """Return the web URL of the first merge request that contains *commit_sha*.

    Returns ``None`` when the lookup yields no merge request, the response
    status is not 200, or any error occurs (errors are logged).
    """
    try:
        import requests

        # API endpoint to find MRs containing the commit
        gitlab_url = get_settings().get("GITLAB.URL", 'https://gitlab.com')
        endpoint = f'{gitlab_url}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/merge_requests'
        response = requests.get(endpoint, headers={'Private-Token': f'{gitlab_token}'})
        merge_requests = response.json()

        if not (merge_requests and response.status_code == 200):
            get_logger().info(f"No merge requests found for commit: {commit_sha}")
            return None
        return merge_requests[0]['web_url']
    except Exception as e:
        get_logger().error(f"Failed to get MR url from commit sha: {e}")
        return None
async def handle_request(api_url: str, body: str, log_context: dict, sender_id: str):
    """Run a single PR-Agent command for a merge request with enriched log context.

    Args:
        api_url: URL of the merge request the command applies to.
        body: Command text to execute; also recorded as the logged action.
        log_context: Mutable logging context; extended in place.
        sender_id: Id of the user who triggered the request (not used here).
    """
    event_kind = "comment"
    if body == "/review":
        event_kind = "pull_request"
    log_context["action"] = body
    log_context["event"] = event_kind
    log_context["api_url"] = api_url
    log_context["app_name"] = get_settings().get("CONFIG.APP_NAME", "Unknown")

    agent_logger = get_logger()
    with agent_logger.contextualize(**log_context):
        await PRAgent().handle_request(api_url, body)
async def _perform_commands_gitlab(commands_conf: str, agent: PRAgent, api_url: str,
                                   log_context: dict, data: dict):
    """Run every command configured under ``gitlab.<commands_conf>`` against a merge request.

    Repo settings are applied first. Nothing runs when auto feedback is disabled
    (for ``pr_commands``) or when ``should_process_pr_logic`` rejects the payload.
    A failure in one command is logged and does not stop the remaining commands.
    """
    apply_repo_settings(api_url)

    # auto commands for PR, and auto feedback is disabled
    skip_auto_feedback = commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
    if skip_auto_feedback:
        get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}", **log_context)
        return

    # Here we already updated the configurations
    if not should_process_pr_logic(data):
        return

    commands = get_settings().get(f"gitlab.{commands_conf}", {})
    get_settings().set("config.is_auto_command", True)
    for command in commands:
        try:
            # First token is the command name, the rest are its arguments.
            command, *args = command.split(" ")
            other_args = update_settings_from_args(args)
            new_command = ' '.join([command] + other_args)
            get_logger().info(f"Performing command: {new_command}")
            with get_logger().contextualize(**log_context):
                await agent.handle_request(api_url, new_command)
        except Exception as e:
            get_logger().error(f"Failed to perform command {command}: {e}")
def is_bot_user(data) -> bool:
    """Tell whether the webhook sender looks like a bot, judged by its display name.

    Returns ``True`` when the lower-cased ``user.name`` contains one of the known
    bot markers; ``False`` otherwise, including when the check itself fails.
    """
    try:
        # GitLab exposes no direct bot flag (unlike Github), so rely on name markers.
        sender_name = data.get("user", {}).get("name", "unknown").lower()
        for marker in ('codium', 'bot_', 'bot-', '_bot', '-bot'):
            if marker in sender_name:
                get_logger().info(f"Skipping GitLab bot user: {sender_name}")
                return True
    except Exception as e:
        get_logger().error(f"Failed 'is_bot_user' logic: {e}")
    return False
def should_process_pr_logic(data) -> bool:
    """Decide whether a webhook payload should be processed, based on the ignore settings.

    Reads ``CONFIG.IGNORE_PR_AUTHORS``, ``CONFIG.IGNORE_PR_TITLE``,
    ``CONFIG.IGNORE_PR_LABELS``, ``CONFIG.IGNORE_PR_SOURCE_BRANCHES`` and
    ``CONFIG.IGNORE_PR_TARGET_BRANCHES``.

    Returns:
        ``False`` when the payload has no (or empty) ``object_attributes`` or when any
        ignore rule matches; ``True`` otherwise, including when the check itself raises.

    NOTE(review): a payload without ``object_attributes`` is always rejected here, and
    the push handler passes its payload through this check -- confirm that push
    payloads actually carry ``object_attributes``.
    """
    def _matches_any(patterns, value) -> bool:
        # re.search() raises TypeError on None; a missing field should simply not match
        # instead of aborting the remaining filters through the broad except below.
        return isinstance(value, str) and any(re.search(regex, value) for regex in patterns)

    try:
        attributes = data.get('object_attributes', {})
        if not attributes:
            return False
        title = attributes.get('title')
        sender = data.get("user", {}).get("username", "")

        # logic to ignore PRs from specific users
        ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
        if ignore_pr_users and sender and sender in ignore_pr_users:
            get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' settings")
            return False

        # logic to ignore MRs for titles, labels and source, target branches.
        ignore_mr_title = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
        ignore_mr_labels = get_settings().get("CONFIG.IGNORE_PR_LABELS", [])
        ignore_mr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [])
        ignore_mr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", [])

        if ignore_mr_source_branches:
            source_branch = attributes.get('source_branch')
            if _matches_any(ignore_mr_source_branches, source_branch):
                get_logger().info(
                    f"Ignoring MR with source branch '{source_branch}' due to "
                    f"'config.ignore_pr_source_branches' settings")
                return False

        if ignore_mr_target_branches:
            target_branch = attributes.get('target_branch')
            if _matches_any(ignore_mr_target_branches, target_branch):
                get_logger().info(
                    f"Ignoring MR with target branch '{target_branch}' due to "
                    f"'config.ignore_pr_target_branches' settings")
                return False

        if ignore_mr_labels:
            labels = [label['title'] for label in attributes.get('labels') or []]
            if any(label in ignore_mr_labels for label in labels):
                labels_str = ", ".join(labels)
                get_logger().info(
                    f"Ignoring MR with labels '{labels_str}' due to 'config.ignore_pr_labels' settings")
                return False

        if ignore_mr_title and _matches_any(ignore_mr_title, title):
            get_logger().info(f"Ignoring MR with title '{title}' due to 'config.ignore_pr_title' settings")
            return False
    except Exception as e:
        # Best effort: a failure in the filter logic must not block processing.
        get_logger().error(f"Failed 'should_process_pr_logic': {e}")
    return True
@router.post("/webhook")
async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
    """Receive a GitLab webhook, acknowledge it and process the payload as a background task.

    The payload is validated and dispatched by the nested ``inner`` coroutine, which is
    scheduled through ``background_tasks``. The HTTP response is therefore always
    ``200 {"message": "success"}``; responses returned by ``inner`` (including the 401
    ones) only end the background processing and are not sent to the caller.

    Handled payloads:
        * merge request ``open`` / ``reopen`` -> configured ``gitlab.pr_commands``
        * note on a merge request -> the note text is executed as a command
        * push -> configured ``gitlab.push_commands`` on the MR containing the commit
    """
    start_time = datetime.now()
    request_json = await request.json()
    # Per-request copy of the global settings; the token override below is applied to this copy.
    context["settings"] = copy.deepcopy(global_settings)

    def _unauthorized():
        return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED,
                            content=jsonable_encoder({"message": "unauthorized"}))

    def _success():
        return JSONResponse(status_code=status.HTTP_200_OK,
                            content=jsonable_encoder({"message": "success"}))

    async def inner(data: dict):
        log_context = {"server_type": "gitlab_app"}
        get_logger().debug("Received a GitLab webhook")

        # --- authentication -------------------------------------------------------
        if request.headers.get("X-Gitlab-Token") and secret_provider:
            request_token = request.headers.get("X-Gitlab-Token")
            secret = secret_provider.get_secret(request_token)
            if not secret:
                # Never write the token itself to the logs: it is a credential.
                get_logger().warning("Empty secret retrieved for the provided X-Gitlab-Token")
                return _unauthorized()
            try:
                secret_dict = json.loads(secret)
                gitlab_token = secret_dict["gitlab_token"]
                log_context["token_id"] = secret_dict.get("token_name", secret_dict.get("id", "unknown"))
                context["settings"].gitlab.personal_access_token = gitlab_token
            except Exception as e:
                get_logger().error(f"Failed to validate secret: {e}")
                return _unauthorized()
        elif get_settings().get("GITLAB.SHARED_SECRET"):
            secret = get_settings().get("GITLAB.SHARED_SECRET")
            if not request.headers.get("X-Gitlab-Token") == secret:
                get_logger().error("Failed to validate secret")
                return _unauthorized()
        else:
            get_logger().error("Failed to validate secret")
            return _unauthorized()

        gitlab_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
        if not gitlab_token:
            get_logger().error("No gitlab token found")
            return _unauthorized()

        get_logger().info("GitLab data", artifact=data)
        sender = data.get("user", {}).get("username", "unknown")
        sender_id = data.get("user", {}).get("id", "unknown")

        # ignore bot users
        if is_bot_user(data):
            return _success()
        if data.get('event_type') != 'note':  # not a comment
            # ignore MRs based on title, labels, source and target branches
            if not should_process_pr_logic(data):
                return _success()

        log_context["sender"] = sender

        # --- dispatch -------------------------------------------------------------
        if data.get('object_kind') == 'merge_request' and data['object_attributes'].get('action') in ['open', 'reopen']:
            url = data['object_attributes'].get('url')
            draft = data['object_attributes'].get('draft')
            get_logger().info(f"New merge request: {url}")
            if draft:
                get_logger().info(f"Skipping draft MR: {url}")
                return _success()
            await _perform_commands_gitlab("pr_commands", PRAgent(), url, log_context, data)
        elif data.get('object_kind') == 'note' and data.get('event_type') == 'note':  # comment on MR
            if 'merge_request' in data:
                mr = data['merge_request']
                url = mr.get('url')
                get_logger().info(f"A comment has been added to a merge request: {url}")
                body = data.get('object_attributes', {}).get('note')
                # `body` may be missing; only inspect it when there is text to search.
                if data.get('object_attributes', {}).get('type') == 'DiffNote' and body and '/ask' in body:  # /ask_line
                    body = handle_ask_line(body, data)
                await handle_request(url, body, log_context, sender_id)
        elif data.get('object_kind') == 'push' and data.get('event_name') == 'push':
            try:
                project_id = data['project_id']
                commit_sha = data['checkout_sha']
                url = await get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id)
                if not url:
                    get_logger().info(f"No MR found for commit: {commit_sha}")
                    return _success()

                # we need first to apply_repo_settings
                apply_repo_settings(url)
                commands_on_push = get_settings().get("gitlab.push_commands", {})
                handle_push_trigger = get_settings().get("gitlab.handle_push_trigger", False)
                if not commands_on_push or not handle_push_trigger:
                    get_logger().info("Push event, but no push commands found or push trigger is disabled")
                    return _success()

                get_logger().debug(f'A push event has been received: {url}')
                await _perform_commands_gitlab("push_commands", PRAgent(), url, log_context, data)
            except Exception as e:
                get_logger().error(f"Failed to handle push event: {e}")

    background_tasks.add_task(inner, request_json)
    end_time = datetime.now()
    # Measures only the time needed to accept and enqueue the request, not the processing itself.
    get_logger().info(f"Processing time: {end_time - start_time}", request=request_json)
    return _success()
def handle_ask_line(body, data):
    """Rewrite an ``/ask`` comment made on a diff into an ``/ask_line`` command.

    Args:
        body: Text of the comment; expected to contain ``/ask``.
        data: Webhook payload; ``object_attributes.position`` supplies the line range
            and file path, ``object_attributes.discussion_id`` the comment id.

    Returns:
        The ``/ask_line ...`` command string, or ``body`` unchanged when the payload
        lacks the expected position fields (the failure is logged).
    """
    try:
        position = data['object_attributes']['position']
        line_range_ = position['line_range']
        # Only the new side of the diff is addressed, hence the fixed 'RIGHT' side below.
        start_line = line_range_['start']['new_line']
        end_line = line_range_['end']['new_line']
        # Strip only the first '/ask' (the command itself); later occurrences belong to
        # the question text and must survive.
        question = body.replace('/ask', '', 1).strip()
        path = position['new_path']
        side = 'RIGHT'
        comment_id = data['object_attributes']["discussion_id"]
        get_logger().info("Handling line comment")
        body = f"/ask_line --line_start={start_line} --line_end={end_line} --side={side} --file_name={path} --comment_id={comment_id} {question}"
    except Exception as e:
        get_logger().error(f"Failed to handle ask line comment: {e}")
    return body
@router.get("/")
async def root():
    """Health-check endpoint: always answers with a static OK status."""
    health = {"status": "ok"}
    return health
# Fail fast at import time: the server cannot work without a configured GitLab URL.
gitlab_url = get_settings().get("GITLAB.URL", None)
if not gitlab_url:
    raise ValueError("GITLAB.URL is not set")
# This server only talks to GitLab, so force the provider regardless of the configured value.
get_settings().config.git_provider = "gitlab"
# RawContextMiddleware is what makes the per-request `context` object available to the handlers.
middleware = [Middleware(RawContextMiddleware)]
app = FastAPI(middleware=middleware)
app.include_router(router)
def start():
    """Serve the webhook application with uvicorn on all interfaces, port 3000."""
    uvicorn.run(app, port=3000, host="0.0.0.0")


# Allow running the server directly as a script.
if __name__ == "__main__":
    start()

View File

@ -0,0 +1,191 @@
import multiprocessing
import os
# from prometheus_client import multiprocess
# Sample Gunicorn configuration file.
#
# Server socket
#
# bind - The socket to bind.
#
# A string of the form: 'HOST', 'HOST:PORT', 'unix:PATH'.
# An IP is a valid HOST.
#
# backlog - The number of pending connections. This refers
# to the number of clients that can be waiting to be
# served. Exceeding this number results in the client
# getting an error when attempting to connect. It should
# only affect servers under significant load.
#
# Must be a positive integer. Generally set in the 64-2048
# range.
#
# bind = '0.0.0.0:5000'
# Listen on every interface, port 3000; allow up to 2048 pending connections.
bind = '0.0.0.0:3000'
backlog = 2048
#
# Worker processes
#
# workers - The number of worker processes that this server
# should keep alive for handling requests.
#
# A positive integer generally in the 2-4 x $(NUM_CORES)
# range. You'll want to vary this a bit to find the best
# for your particular application's work load.
#
# worker_class - The type of workers to use. The default
# sync class should handle most 'normal' types of work
# loads. You'll want to read
# http://docs.gunicorn.org/en/latest/design.html#choosing-a-worker-type
# for information on when you might want to choose one
# of the other worker classes.
#
# A string referring to a Python path to a subclass of
# gunicorn.workers.base.Worker. The default provided values
# can be seen at
# http://docs.gunicorn.org/en/latest/settings.html#worker-class
#
# worker_connections - For the eventlet and gevent worker classes
# this limits the maximum number of simultaneous clients that
# a single process can handle.
#
# A positive integer generally set to around 1000.
#
# timeout - If a worker does not notify the master process in this
# number of seconds it is killed and a new worker is spawned
# to replace it.
#
# Generally set to thirty seconds. Only set this noticeably
# higher if you're sure of the repercussions for sync workers.
# For the non sync workers it just means that the worker
# process is still communicating and is not tied to the length
# of time required to handle a single request.
#
# keepalive - The number of seconds to wait for the next request
# on a Keep-Alive HTTP connection.
#
# A positive integer. Generally set in the 1-5 seconds range.
#
# Worker count: an explicit GUNICORN_WORKERS value wins; otherwise derive it
# from the CPU count as 2 * cores + 1.
_workers_from_env = os.getenv('GUNICORN_WORKERS', None)
if not _workers_from_env:
    cores = multiprocessing.cpu_count()
    workers = 2 * cores + 1
else:
    workers = int(_workers_from_env)
worker_connections = 1000
timeout = 240
keepalive = 2
#
# spew - Install a trace function that spews every line of Python
# that is executed when running the server. This is the
# nuclear option.
#
# True or False
#
# Line-by-line execution tracing stays off.
spew = False
#
# Server mechanics
#
# daemon - Detach the main Gunicorn process from the controlling
# terminal with a standard fork/fork sequence.
#
# True or False
#
# raw_env - Pass environment variables to the execution environment.
#
# pidfile - The path to a pid file to write
#
# A path string or None to not write a pid file.
#
# user - Switch worker processes to run as this user.
#
# A valid user id (as an integer) or the name of a user that
# can be retrieved with a call to pwd.getpwnam(value) or None
# to not change the worker process user.
#
# group - Switch worker process to run as this group.
#
# A valid group id (as an integer) or the name of a group that
# can be retrieved with a call to grp.getgrnam(value) or None
# to not change the worker processes group.
#
# umask - A mask for file permissions written by Gunicorn. Note that
# this affects unix socket permissions.
#
# A valid value for the os.umask(mode) call or a string
# compatible with int(value, 0) (0 means Python guesses
# the base, so values like "0", "0xFF", "0022" are valid
# for decimal, hex, and octal representations)
#
# tmp_upload_dir - A directory to store temporary request data when
# requests are read. This will most likely be disappearing soon.
#
# A path to a directory where the process owner can write. Or
# None to signal that Python should choose one on its own.
#
# Server mechanics: run in the foreground, pass no extra environment, write no pid file,
# and leave the worker user/group and the upload directory unset.
daemon = False
raw_env = []
pidfile = None
umask = 0
user = None
group = None
tmp_upload_dir = None
#
# Logging
#
# logfile - The path to a log file to write to.
#
# A path string. "-" means log to stdout.
#
# loglevel - The granularity of log output
#
# A string of "debug", "info", "warning", "error", "critical"
#
# Logging: error log goes to "-" at "info" level; no access log file is configured,
# although its line format is still defined.
errorlog = '-'
loglevel = 'info'
accesslog = None
access_log_format = '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"'
#
# Process naming
#
# proc_name - A base to use with setproctitle to change the way
# that Gunicorn processes are reported in the system process
# table. This affects things like 'ps' and 'top'. If you're
# going to be running more than one instance of Gunicorn you'll
# probably want to set a name to tell them apart. This requires
# that you install the setproctitle module.
#
# A string or None to choose a default of something like 'gunicorn'.
#
# No custom process title; Gunicorn picks its default.
proc_name = None
#
# Server hooks
#
# post_fork - Called just after a worker has been forked.
#
# A callable that takes a server and worker instance
# as arguments.
#
# pre_fork - Called just prior to forking the worker subprocess.
#
# A callable that accepts the same arguments as post_fork
#
# pre_exec - Called just prior to forking off a secondary
# master process during things like config reloading.
#
# A callable that takes a server instance as the sole argument.
#

View File

@ -0,0 +1,203 @@
class HelpMessage:
    """Static builders for the markdown/HTML help texts that PR-Agent posts as comments."""

    @staticmethod
    def get_general_commands_text():
        """Return the quoted markdown list of the general PR-Agent commands."""
        commands_text = (
            "> - **/review**: Request a review of your Pull Request. \n"
            "> - **/describe**: Update the PR title and description based on the contents of the PR. \n"
            "> - **/improve [--extended]**: Suggest code improvements. Extended mode provides a higher quality feedback. \n"
            "> - **/ask \\<QUESTION\\>**: Ask a question about the PR. \n"
            "> - **/update_changelog**: Update the changelog based on the PR's contents. \n"
            "> - **/add_docs** 💎: Generate docstring for new components introduced in the PR. \n"
            "> - **/generate_labels** 💎: Generate labels for the PR based on the PR's contents. \n"
            "> - **/analyze** 💎: Automatically analyzes the PR, and presents changes walkthrough for each component. \n\n"
            ">See the [tools guide](https://pr-agent-docs.codium.ai/tools/) for more details.\n"
            ">To list the possible configuration parameters, add a **/config** comment. \n"
        )
        return commands_text

    @staticmethod
    def get_general_bot_help_text():
        """Return the invocation hint followed by the general command list."""
        output = f"> To invoke the PR-Agent, add a comment using one of the following commands: \n{HelpMessage.get_general_commands_text()} \n"
        return output

    @staticmethod
    def get_review_usage_guide():
        """Return the usage guide for the `review` tool."""
        output = "**Overview:**\n"
        output += ("The `review` tool scans the PR code changes, and generates a PR review which includes several types of feedbacks, such as possible PR issues, security threats and relevant test in the PR. More feedbacks can be [added](https://pr-agent-docs.codium.ai/tools/review/#general-configurations) by configuring the tool.\n\n"
                   "The tool can be triggered [automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) every time a new PR is opened, or can be invoked manually by commenting on any PR.\n")
        output += """\
- When commenting, to edit [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L23) related to the review tool (`pr_reviewer` section), use the following template:
```
/review --pr_reviewer.some_config1=... --pr_reviewer.some_config2=...
```
- With a [configuration file](https://pr-agent-docs.codium.ai/usage-guide/configuration_options/), use the following template:
```
[pr_reviewer]
some_config1=...
some_config2=...
```
"""
        output += "\n\nSee the review [usage page](https://pr-agent-docs.codium.ai/tools/review/) for a comprehensive guide on using this tool.\n\n"
        return output

    @staticmethod
    def get_describe_usage_guide():
        """Return the usage guide for the `describe` tool.

        The guide is an overview followed by an HTML table in which every topic is a
        collapsible ``<details>`` row; each row that is opened is also closed.
        """
        output = "**Overview:**\n"
        output += "The `describe` tool scans the PR code changes, and generates a description for the PR - title, type, summary, walkthrough and labels. "
        output += "The tool can be triggered [automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) every time a new PR is opened, or can be invoked manually by commenting on a PR.\n"
        output += """\
When commenting, to edit [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L46) related to the describe tool (`pr_description` section), use the following template:
```
/describe --pr_description.some_config1=... --pr_description.some_config2=...
```
With a [configuration file](https://pr-agent-docs.codium.ai/usage-guide/configuration_options/), use the following template:
```
[pr_description]
some_config1=...
some_config2=...
```
"""
        output += "\n\n<table>"

        # automation
        output += "<tr><td><details> <summary><strong> Enabling\\disabling automation </strong></summary><hr>\n\n"
        output += """\
- When you first install the app, the [default mode](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) for the describe tool is:
```
pr_commands = ["/describe", ...]
```
meaning the `describe` tool will run automatically on every PR.
- Markers are an alternative way to control the generated description, to give maximal control to the user. If you set:
```
pr_commands = ["/describe --pr_description.use_description_markers=true", ...]
```
the tool will replace every marker of the form `pr_agent:marker_name` in the PR description with the relevant content, where `marker_name` is one of the following:
- `type`: the PR type.
- `summary`: the PR summary.
- `walkthrough`: the PR walkthrough.
Note that when markers are enabled, if the original PR description does not contain any markers, the tool will not alter the description at all.
"""
        output += "\n\n</details></td></tr>\n\n"

        # custom labels
        output += "<tr><td><details> <summary><strong> Custom labels </strong></summary><hr>\n\n"
        output += """\
The default labels of the `describe` tool are quite generic: [`Bug fix`, `Tests`, `Enhancement`, `Documentation`, `Other`].
If you specify [custom labels](https://pr-agent-docs.codium.ai/tools/describe/#handle-custom-labels-from-the-repos-labels-page) in the repo's labels page or via configuration file, you can get tailored labels for your use cases.
Examples for custom labels:
- `Main topic:performance` - pr_agent:The main topic of this PR is performance
- `New endpoint` - pr_agent:A new endpoint was added in this PR
- `SQL query` - pr_agent:A new SQL query was added in this PR
- `Dockerfile changes` - pr_agent:The PR contains changes in the Dockerfile
- ...
The list above is eclectic, and aims to give an idea of different possibilities. Define custom labels that are relevant for your repo and use cases.
Note that Labels are not mutually exclusive, so you can add multiple label categories.
Make sure to provide proper title, and a detailed and well-phrased description for each label, so the tool will know when to suggest it.
"""
        output += "\n\n</details></td></tr>\n\n"

        # Inline File Walkthrough
        output += "<tr><td><details> <summary><strong> Inline File Walkthrough 💎</strong></summary><hr>\n\n"
        output += """\
For enhanced user experience, the `describe` tool can add file summaries directly to the "Files changed" tab in the PR page.
This will enable you to quickly understand the changes in each file, while reviewing the code changes (diffs).
To enable inline file summary, set `pr_description.inline_file_summary` in the configuration file, possible values are:
- `'table'`: File changes walkthrough table will be displayed on the top of the "Files changed" tab, in addition to the "Conversation" tab.
- `true`: A collapsable file comment with changes title and a changes summary for each file in the PR.
- `false` (default): File changes walkthrough will be added only to the "Conversation" tab.
"""
        # Close this row like every other one; without it the following rows end up
        # nested inside this <details> element.
        output += "\n\n</details></td></tr>\n\n"

        # extra instructions
        output += "<tr><td><details> <summary><strong> Utilizing extra instructions</strong></summary><hr>\n\n"
        output += '''\
The `describe` tool can be configured with extra instructions, to guide the model to a feedback tailored to the needs of your project.
Be specific, clear, and concise in the instructions. With extra instructions, you are the prompter. Notice that the general structure of the description is fixed, and cannot be changed. Extra instructions can change the content or style of each sub-section of the PR description.
Examples for extra instructions:
```
[pr_description]
extra_instructions="""\
- The PR title should be in the format: '<PR type>: <title>'
- The title should be short and concise (up to 10 words)
- ...
"""
```
Use triple quotes to write multi-line instructions. Use bullet points to make the instructions more readable.
'''
        output += "\n\n</details></td></tr>\n\n"

        # general
        output += "\n\n<tr><td><details> <summary><strong> More PR-Agent commands</strong></summary><hr> \n\n"
        output += HelpMessage.get_general_bot_help_text()
        output += "\n\n</details></td></tr>\n\n"
        output += "</table>"
        output += "\n\nSee the [describe usage](https://pr-agent-docs.codium.ai/tools/describe/) page for a comprehensive guide on using this tool.\n\n"
        return output

    @staticmethod
    def get_ask_usage_guide():
        """Return the usage guide for the `ask` tool."""
        output = "**Overview:**\n"
        output += """\
The `ask` tool answers questions about the PR, based on the PR code changes.
It can be invoked manually by commenting on any PR:
```
/ask "..."
```
Note that the tool does not have "memory" of previous questions, and answers each question independently.
You can ask questions about the entire PR, about specific code lines, or about an image related to the PR code changes.
"""
        output += "\n\nSee the [ask usage](https://pr-agent-docs.codium.ai/tools/ask/) page for a comprehensive guide on using this tool.\n\n"
        return output

    @staticmethod
    def get_improve_usage_guide():
        """Return the usage guide for the `improve` (code suggestions) tool."""
        output = "**Overview:**\n"
        # Trailing space keeps this sentence separated from the next one.
        output += "The code suggestions tool, named `improve`, scans the PR code changes, and automatically generates code suggestions for improving the PR. "
        output += "The tool can be triggered [automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) every time a new PR is opened, or can be invoked manually by commenting on a PR.\n"
        output += """\
- When commenting, to edit [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L78) related to the improve tool (`pr_code_suggestions` section), use the following template:
```
/improve --pr_code_suggestions.some_config1=... --pr_code_suggestions.some_config2=...
```
- With a [configuration file](https://pr-agent-docs.codium.ai/usage-guide/configuration_options/), use the following template:
```
[pr_code_suggestions]
some_config1=...
some_config2=...
```
"""
        output += "\n\nSee the improve [usage page](https://pr-agent-docs.codium.ai/tools/improve/) for a comprehensive guide on using this tool.\n\n"
        return output

View File

@ -0,0 +1,16 @@
from fastapi import FastAPI
from mangum import Mangum
from starlette.middleware import Middleware
from starlette_context.middleware import RawContextMiddleware
from utils.pr_agent.servers.github_app import router
# Build a FastAPI app around the GitHub-app router, with the raw context middleware installed.
middleware = [Middleware(RawContextMiddleware)]
app = FastAPI(middleware=middleware)
app.include_router(router)
# Mangum wraps the app so it can be called with an (event, context) pair; lifespan handling is off.
handler = Mangum(app, lifespan="off")
def serverless(event, context):
    """Entry point taking an (event, context) pair; delegates to the Mangum handler."""
    response = handler(event, context)
    return response

View File

@ -0,0 +1,86 @@
import hashlib
import hmac
import time
from collections import defaultdict
from typing import Any, Callable
from fastapi import HTTPException
def verify_signature(payload_body, secret_token, signature_header):
    """Verify that the payload was sent from GitHub by validating its SHA256 HMAC.

    Args:
        payload_body: original request body to verify (request.body()), as bytes
        secret_token: GitHub app webhook token (WEBHOOK_SECRET)
        signature_header: header received from GitHub (x-hub-signature-256)

    Raises:
        HTTPException: with status 403 when the header is missing or does not match.
    """
    if not signature_header:
        raise HTTPException(status_code=403, detail="x-hub-signature-256 header is missing!")
    hash_object = hmac.new(secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256)
    expected_signature = "sha256=" + hash_object.hexdigest()
    # Compare as bytes: compare_digest() raises TypeError for str arguments that contain
    # non-ASCII characters, which would turn a forged header into an unhandled error
    # instead of the 403 below.
    signatures_match = hmac.compare_digest(
        expected_signature.encode('utf-8'),
        signature_header.encode('utf-8'),
    )
    if not signatures_match:
        raise HTTPException(status_code=403, detail="Request signatures didn't match!")
class RateLimitExceeded(Exception):
    """Signals that the rate limit of the git provider API has been exceeded."""
class DefaultDictWithTimeout(defaultdict):
    """A defaultdict whose entries expire after a time-to-live (TTL).

    Expired entries are purged lazily: item lookups (``d[key]``) trigger a sweep,
    at most once per ``refresh_interval`` seconds. Other accessors such as ``get``
    or ``in`` never purge.
    """

    def __init__(
        self,
        default_factory: Callable[[], Any] = None,
        ttl: int = None,
        refresh_interval: int = 60,
        update_key_time_on_get: bool = True,
        *args,
        **kwargs,
    ):
        """
        Args:
            default_factory: The default factory to use for keys that are not in the dictionary.
            ttl: The time-to-live (TTL) in seconds. ``None`` disables expiry.
            refresh_interval: How often to refresh the dict and delete items older than the TTL.
            update_key_time_on_get: Whether to update the access time of a key also on get (or only when set).
        """
        super().__init__(default_factory, *args, **kwargs)
        now = self.__time()
        # Items passed to the constructor bypass __setitem__, so stamp them here;
        # otherwise they would never expire.
        self.__key_times = {key: now for key in self}
        self.__ttl = ttl
        self.__refresh_interval = refresh_interval
        self.__update_key_time_on_get = update_key_time_on_get
        # Back-date the last refresh so the very first lookup is allowed to sweep.
        self.__last_refresh = now - self.__refresh_interval

    @staticmethod
    def __time():
        return time.monotonic()

    def __refresh(self):
        """Delete every entry older than the TTL, at most once per refresh interval."""
        if self.__ttl is None:
            return
        request_time = self.__time()
        # Skip while the interval has NOT yet elapsed since the last sweep.
        if request_time - self.__last_refresh < self.__refresh_interval:
            return
        to_delete = [key for key, key_time in self.__key_times.items() if request_time - key_time > self.__ttl]
        for key in to_delete:
            try:
                del self[key]
            except KeyError:
                # The timestamp outlived its entry; dropping the timestamp is enough.
                self.__key_times.pop(key, None)
        self.__last_refresh = request_time

    def __getitem__(self, __key):
        # Only refresh the timestamp of keys that exist; missing keys are stamped by
        # __setitem__ if (and only if) the default factory creates them.
        if self.__update_key_time_on_get and super().__contains__(__key):
            self.__key_times[__key] = self.__time()
        self.__refresh()
        return super().__getitem__(__key)

    def __setitem__(self, __key, __value):
        self.__key_times[__key] = self.__time()
        return super().__setitem__(__key, __value)

    def __delitem__(self, __key):
        # Tolerate entries without a timestamp (e.g. added through update()).
        self.__key_times.pop(__key, None)
        return super().__delitem__(__key)

View File

@ -0,0 +1,96 @@
# QUICKSTART:
# Copy this file to .secrets.toml in the same folder.
# The minimum workable settings - set openai.key to your API key.
# Set github.deployment_type to "user" and github.user_token to your GitHub personal access token.
# This will allow you to run the CLI scripts in the scripts/ folder and the github_polling server.
#
# See README for details about GitHub App deployment.
[openai]
key = "" # Acquire through https://platform.openai.com
#org = "<ORGANIZATION>" # Optional, may be commented out.
# Uncomment the following for Azure OpenAI
#api_type = "azure"
#api_version = '2023-05-15' # Check Azure documentation for the current API version
#api_base = "" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
#deployment_id = "" # The deployment name you chose when you deployed the engine
#fallback_deployments = [] # For each fallback model specified in configuration.toml in the [config] section, specify the appropriate deployment_id
[pinecone]
api_key = "..."
environment = "gcp-starter"
[anthropic]
key = "" # Optional, uncomment if you want to use Anthropic. Acquire through https://www.anthropic.com/
[cohere]
key = "" # Optional, uncomment if you want to use Cohere. Acquire through https://dashboard.cohere.ai/
[replicate]
key = "" # Optional, uncomment if you want to use Replicate. Acquire through https://replicate.com/
[groq]
key = "" # Acquire through https://console.groq.com/keys
[huggingface]
key = "" # Optional, uncomment if you want to use Huggingface Inference API. Acquire through https://huggingface.co/docs/api-inference/quicktour
api_base = "" # the base url for your huggingface inference endpoint
[ollama]
api_base = "" # the base url for your local Llama 2, Code Llama, and other models inference endpoint. Acquire through https://ollama.ai/
[vertexai]
vertex_project = "" # the google cloud platform project name for your vertexai deployment
vertex_location = "" # the google cloud platform location for your vertexai deployment
[google_ai_studio]
gemini_api_key = "" # the google AI Studio API key
[github]
# ---- Set the following only for deployment type == "user"
user_token = "" # A GitHub personal access token with 'repo' scope.
deployment_type = "user" #set to user by default
# ---- Set the following only for deployment type == "app", see README for details.
private_key = """\
-----BEGIN RSA PRIVATE KEY-----
<GITHUB PRIVATE KEY>
-----END RSA PRIVATE KEY-----
"""
app_id = 123456 # The GitHub App ID, replace with your own.
webhook_secret = "<WEBHOOK SECRET>" # Optional, may be commented out.
[gitlab]
# Gitlab personal access token
personal_access_token = ""
shared_secret = "" # webhook secret
[bitbucket]
# For Bitbucket personal/repository bearer token
bearer_token = ""
[bitbucket_server]
# For Bitbucket Server bearer token
bearer_token = ""
webhook_secret = ""
# For Bitbucket app
app_key = ""
base_url = ""
[litellm]
LITELLM_TOKEN = "" # see https://docs.litellm.ai/docs/debugging/hosted_debugging for details and instructions on how to get a token
[azure_devops]
# For Azure devops personal access token
org = ""
pat = ""
[azure_devops_server]
# For Azure devops Server basic auth - configured in the webhook creation
# Optional, uncomment if you want to use Azure devops webhooks. Value assigned when you create the webhook
# webhook_username = "<basic auth user>"
# webhook_password = "<basic auth password>"
[deepseek]
key = ""

View File

@ -0,0 +1,333 @@
[config]
# models
model="o3-mini"
fallback_models=["o3-mini"]
# model_weak="gpt-4o-mini" # optional, a weaker model to use for some easier tasks
# CLI
git_provider="gitlab"
publish_output=true
publish_output_progress=true
publish_output_no_suggestions=true
verbosity_level=0 # 0,1,2
use_extra_bad_extensions=false
# Configurations
use_wiki_settings_file=true
use_repo_settings_file=true
use_global_settings_file=true
disable_auto_feedback = false
ai_timeout=120 # 2 minutes
skip_keys = []
custom_reasoning_model = true # when true, disables system messages and temperature controls for models that don't support chat-style inputs
# token limits
max_description_tokens = 500
max_commits_tokens = 500
max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
custom_model_max_tokens=-1 # for models not in the default list
# patch extension logic
patch_extension_skip_types =[".md",".txt"]
allow_dynamic_context=true
max_extra_lines_before_dynamic_context = 8 # will try to include up to 8 extra lines before the hunk in the patch, until we reach an enclosing function or class
patch_extra_lines_before = 3 # Number of extra lines (+3 default ones) to include before each hunk in the patch
patch_extra_lines_after = 1 # Number of extra lines (+3 default ones) to include after each hunk in the patch
secret_provider=""
cli_mode=false
ai_disclaimer_title="" # Pro feature, title for a collapsible disclaimer to AI outputs
ai_disclaimer="" # Pro feature, full text for the AI disclaimer
output_relevant_configurations=false
large_patch_policy = "clip" # "clip", "skip"
duplicate_prompt_examples = false
# seed
seed=-1 # set positive value to fix the seed (and ensure temperature=0)
temperature=0.2
# ignore logic
ignore_pr_title = ["^\\[Auto\\]", "^Auto"] # a list of regular expressions to match against the PR title to ignore the PR agent
ignore_pr_target_branches = [] # a list of regular expressions of target branches to ignore from PR agent when a PR is created
ignore_pr_source_branches = [] # a list of regular expressions of source branches to ignore from PR agent when a PR is created
ignore_pr_labels = [] # labels to ignore from PR agent when a PR is created
ignore_pr_authors = [] # authors to ignore from PR agent when a PR is created
#
is_auto_command = false # will be auto-set to true if the command is triggered by an automation
enable_ai_metadata = false # will enable adding ai metadata
# auto approval 💎
enable_auto_approval=false # Set to true to enable auto-approval of PRs under certain conditions
auto_approve_for_low_review_effort=-1 # -1 to disable, [1-5] to set the threshold for auto-approval
auto_approve_for_no_suggestions=false # If true, the PR will be auto-approved if there are no suggestions
[pr_reviewer] # /review #
# enable/disable features
require_score_review=false
require_tests_review=true
require_estimate_effort_to_review=true
require_can_be_split_review=false
require_security_review=true
require_ticket_analysis_review=true
# general options
persistent_comment=true
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号"
final_update_message = true
# review labels
enable_review_labels_security=true
enable_review_labels_effort=true
# specific configurations for incremental review (/review -i)
require_all_thresholds_for_incremental_review=false
minimal_commits_for_incremental_review=0
minimal_minutes_for_incremental_review=0
enable_intro_text=true
enable_help_text=false # Determines whether to include help text in the PR review. Disabled by default.
[pr_description] # /describe #
publish_labels=false
add_original_user_description=true
generate_ai_title=false
use_bullet_points=true
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号"
enable_pr_type=true
final_update_message = true
enable_help_text=false
enable_help_comment=true
# describe as comment
publish_description_as_comment=false
publish_description_as_comment_persistent=true
## changes walkthrough section
enable_semantic_files_types=true
collapsible_file_list='adaptive' # true, false, 'adaptive'
collapsible_file_list_threshold=8
inline_file_summary=false # false, true, 'table'
# markers
use_description_markers=false
include_generated_by_header=true
# large pr mode 💎
enable_large_pr_handling=true
max_ai_calls=4
async_ai_calls=true
#custom_labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Enhancement', 'Documentation', 'Other']
[pr_questions] # /ask #
enable_help_text=false
[pr_code_suggestions] # /improve #
max_context_tokens=16000
#
commitable_code_suggestions = false
dual_publishing_score_threshold=-1 # -1 to disable, [0-10] to set the threshold (>=) for publishing a code suggestion both in a table and as commitable
focus_only_on_problems=true
#
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号"
enable_help_text=false
enable_chat_text=false
enable_intro_text=true
persistent_comment=true
max_history_len=4
# enable to apply suggestion 💎
apply_suggestions_checkbox=true
# suggestions scoring
suggestions_score_threshold=0 # [0-10]; setting this value above 8 is not recommended, since a higher threshold may filter out highly relevant suggestions
new_score_mechanism=true
new_score_mechanism_th_high=9
new_score_mechanism_th_medium=7
# params for '/improve --extended' mode
auto_extended_mode=true
num_code_suggestions_per_chunk=4
max_number_of_calls = 3
parallel_calls = true
final_clip_factor = 0.8
# self-review checkbox
demand_code_suggestions_self_review=false # add a checkbox for the author to self-review the code suggestions
code_suggestions_self_review_text= "**作者自我审查**我已审核PR代码建议并处理了相关内容."
approve_pr_on_self_review=false # Pro feature. if true, the PR will be auto-approved after the author clicks on the self-review checkbox
fold_suggestions_on_self_review=true # Pro feature. if true, the code suggestions will be folded after the author clicks on the self-review checkbox
# Suggestion impact 💎
publish_post_process_suggestion_impact=true
wiki_page_accepted_suggestions=true
allow_thumbs_up_down=false
[pr_custom_prompt] # /custom_prompt #
prompt = """\
The code suggestions should focus only on the following:
- ...
- ...
...
"""
suggestions_score_threshold=0
num_code_suggestions_per_chunk=4
self_reflect_on_custom_suggestions=true
enable_help_text=false
[pr_add_docs] # /add_docs #
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号"
docs_style = "Sphinx" # "Google Style with Args, Returns, Attributes...etc", "Numpy Style", "Sphinx Style", "PEP257", "reStructuredText"
file = "" # in case there are several components with the same name, you can specify the relevant file
class_name = "" # in case there are several methods with the same name in the same file, you can specify the relevant class name
[pr_update_changelog] # /update_changelog #
push_changelog_changes=false
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号"
add_pr_link=true
[pr_analyze] # /analyze #
enable_help_text=true
[pr_test] # /test #
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号"
testing_framework = "" # specify the testing framework you want to use
num_tests=3 # number of tests to generate. max 5.
avoid_mocks=true # if true, the generated tests will prefer to use real objects instead of mocks
file = "" # in case there are several components with the same name, you can specify the relevant file
class_name = "" # in case there are several methods with the same name in the same file, you can specify the relevant class name
enable_help_text=false
[pr_improve_component] # /improve_component #
num_code_suggestions=4
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号"
file = "" # in case there are several components with the same name, you can specify the relevant file
class_name = "" # in case there are several methods with the same name in the same file, you can specify the relevant class name
[checks] # /checks (pro feature) #
enable_auto_checks_feedback=true
excluded_checks_list=["lint"] # list of checks to exclude, for example: ["check1", "check2"]
persistent_comment=true
enable_help_text=true
final_update_message = false
[pr_help] # /help #
force_local_db=false
num_retrieved_snippets=5
[pr_config] # /config #
[github]
# The type of deployment to create. Valid values are 'app' or 'user'.
deployment_type = "user"
ratelimit_retries = 5
base_url = "https://api.github.com"
publish_inline_comments_fallback_with_verification = true
try_fix_invalid_inline_comments = true
app_name = "pr-agent"
ignore_bot_pr = true
[github_action_config]
# auto_review = true # set as env var in .github/workflows/pr-agent.yaml
# auto_describe = true # set as env var in .github/workflows/pr-agent.yaml
# auto_improve = true # set as env var in .github/workflows/pr-agent.yaml
# pr_actions = ['opened', 'reopened', 'ready_for_review', 'review_requested']
[github_app]
# these toggles allow running the github app from custom deployments
bot_user = "github-actions[bot]"
override_deployment_type = true
# settings for "pull_request" event
handle_pr_actions = ['opened', 'reopened', 'ready_for_review']
pr_commands = [
"/describe --pr_description.final_update_message=false",
"/review",
"/improve",
]
# settings for "pull_request" event with "synchronize" action - used to detect and handle push triggers for new commits
handle_push_trigger = false
push_trigger_ignore_bot_commits = true
push_trigger_ignore_merge_commits = true
push_trigger_wait_for_initial_review = true
push_trigger_pending_tasks_backlog = true
push_trigger_pending_tasks_ttl = 300
push_commands = [
"/describe",
"/review",
]
[gitlab]
url = "http://192.168.1.91:4010"
pr_commands = [
"/describe --pr_description.final_update_message=false",
"/review",
"/improve",
]
handle_push_trigger = true
push_commands = [
"/describe",
"/review",
]
[bitbucket_app]
pr_commands = [
"/describe --pr_description.final_update_message=false",
"/review",
"/improve --pr_code_suggestions.commitable_code_suggestions=true",
]
avoid_full_files = false
[local]
# LocalGitProvider settings - uncomment to use paths other than default
# description_path= "path/to/description.md"
# review_path= "path/to/review.md"
[gerrit]
# endpoint to the gerrit service
# url = "ssh://gerrit.example.com:29418"
# user for gerrit authentication
# user = "ai-reviewer"
# patch server where patches will be saved
# patch_server_endpoint = "http://127.0.0.1:5000/patch"
# token to authenticate in the patch server
# patch_server_token = ""
[bitbucket_server]
# URL to the BitBucket Server instance
# url = "https://git.bitbucket.com"
url = ""
pr_commands = [
"/describe --pr_description.final_update_message=false",
"/review",
"/improve --pr_code_suggestions.commitable_code_suggestions=true",
]
[litellm]
# use_client = false
# drop_params = false
enable_callbacks = false
success_callback = []
failure_callback = []
service_callback = []
[pr_similar_issue]
skip_comments = false
force_update_dataset = false
max_issues_to_scan = 500
vectordb = "pinecone"
[pr_find_similar_component]
class_name = ""
file = ""
search_from_org = false
allow_fallback_less_words = true
number_of_keywords = 5
number_of_results = 5
[pinecone]
# fill and place in .secrets.toml
#api_key = ...
# environment = "gcp-starter"
[lancedb]
uri = "./lancedb"
[best_practices]
content = ""
organization_name = ""
max_lines_allowed = 800
enable_global_best_practices = false
[auto_best_practices]
enable_auto_best_practices = true # public - general flag to disable all auto best practices usage
utilize_auto_best_practices = true # public - disable usage of auto best practices in the 'improve' tool
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号" # public - extra instructions to the auto best practices generation prompt
content = ""
max_patterns = 5 # max number of patterns to be detected
[openai]
api_base = "http://110.40.24.85:3000/v1"
[llm]
model = "o3-mini"
force_config = true

View File

@ -0,0 +1,16 @@
[config]
enable_custom_labels=false
## template for custom labels
#[custom_labels."Bug fix"]
#description = """Fixes a bug in the code"""
#[custom_labels."Tests"]
#description = """Adds or modifies tests"""
#[custom_labels."Bug fix with tests"]
#description = """Fixes a bug in the code and adds or modifies tests"""
#[custom_labels."Enhancement"]
#description = """Adds new features or modifies existing ones"""
#[custom_labels."Documentation"]
#description = """Adds or modifies documentation"""
#[custom_labels."Other"]
#description = """Other changes that do not fit in any of the above categories"""

View File

@ -0,0 +1,12 @@
[ignore]
glob = [
# Ignore files and directories matching these glob patterns.
# See https://docs.python.org/3/library/glob.html
'vendor/**',
]
regex = [
# Ignore files and directories matching these regex patterns.
# See https://learnbyexample.github.io/python-regex-cheatsheet/
# for example: regex = ['.*\.toml$']
]

View File

@ -0,0 +1,440 @@
[bad_extensions]
default = [
'app',
'bin',
'bmp',
'bz2',
'class',
'csv',
'dat',
'db',
'dll',
'dylib',
'egg',
'eot',
'exe',
'gif',
'gitignore',
'glif',
'gradle',
'gz',
'ico',
'jar',
'jpeg',
'jpg',
'lo',
'lock',
'log',
'mp3',
'mp4',
'nar',
'o',
'ogg',
'otf',
'p',
'pdf',
'png',
'pickle',
'pkl',
'pyc',
'pyd',
'pyo',
'rkt',
'so',
'ss',
'svg',
'tar',
'tgz',
'tsv',
'ttf',
'war',
'webm',
'woff',
'woff2',
'xz',
'zip',
'zst',
'snap',
'lockb'
]
extra = [
'md',
'txt'
]
[language_extension_map_org]
"1C Enterprise" = ["*.bsl", ]
ABAP = [".abap", ]
"AGS Script" = [".ash", ]
AMPL = [".ampl", ]
ANTLR = [".g4", ]
"API Blueprint" = [".apib", ]
APL = [".apl", ".dyalog", ]
ASP = [".asp", ".asax", ".ascx", ".ashx", ".asmx", ".aspx", ".axd", ]
ATS = [".dats", ".hats", ".sats", ]
ActionScript = [".as", ]
Ada = [".adb", ".ada", ".ads", ]
Agda = [".agda", ]
Alloy = [".als", ]
ApacheConf = [".apacheconf", ".vhost", ]
AppleScript = [".applescript", ".scpt", ]
Arc = [".arc", ]
Arduino = [".ino", ]
AsciiDoc = [".asciidoc", ".adoc", ]
AspectJ = [".aj", ]
Assembly = [".asm", ".a51", ".nasm", ]
Augeas = [".aug", ]
AutoHotkey = [".ahk", ".ahkl", ]
AutoIt = [".au3", ]
Awk = [".awk", ".auk", ".gawk", ".mawk", ".nawk", ]
Batchfile = [".bat", ".cmd", ]
Befunge = [".befunge", ]
Bison = [".bison", ]
BitBake = [".bb", ]
BlitzBasic = [".decls", ]
BlitzMax = [".bmx", ]
Bluespec = [".bsv", ]
Boo = [".boo", ]
Brainfuck = [".bf", ]
Brightscript = [".brs", ]
Bro = [".bro", ]
C = [".c", ".cats", ".h", ".idc", ".w", ]
"C#" = [".cs", ".cake", ".cshtml", ".csx", ]
"C++" = [".cpp", ".c++", ".cc", ".cp", ".cxx", ".h++", ".hh", ".hpp", ".hxx", ".inl", ".ipp", ".tcc", ".tpp", ".C", ".H", ]
C-ObjDump = [".c-objdump", ]
"C2hs Haskell" = [".chs", ]
CLIPS = [".clp", ]
CMake = [".cmake", ".cmake.in", ]
COBOL = [".cob", ".cbl", ".ccp", ".cobol", ".cpy", ]
CSS = [".css", ]
CSV = [".csv", ]
"Cap'n Proto" = [".capnp", ]
CartoCSS = [".mss", ]
Ceylon = [".ceylon", ]
Chapel = [".chpl", ]
ChucK = [".ck", ]
Cirru = [".cirru", ]
Clarion = [".clw", ]
Clean = [".icl", ".dcl", ]
Click = [".click", ]
Clojure = [".clj", ".boot", ".cl2", ".cljc", ".cljs", ".cljs.hl", ".cljscm", ".cljx", ".hic", ]
CoffeeScript = [".coffee", "._coffee", ".cjsx", ".cson", ".iced", ]
ColdFusion = [".cfm", ".cfml", ]
"ColdFusion CFC" = [".cfc", ]
"Common Lisp" = [".lisp", ".asd", ".lsp", ".ny", ".podsl", ".sexp", ]
"Component Pascal" = [".cps", ]
Coq = [".coq", ]
Cpp-ObjDump = [".cppobjdump", ".c++-objdump", ".c++objdump", ".cpp-objdump", ".cxx-objdump", ]
Creole = [".creole", ]
Crystal = [".cr", ]
Csound = [".csd", ]
Cucumber = [".feature", ]
Cuda = [".cu", ".cuh", ]
Cycript = [".cy", ]
Cython = [".pyx", ".pxd", ".pxi", ]
D = [".di", ]
D-ObjDump = [".d-objdump", ]
"DIGITAL Command Language" = [".com", ]
DM = [".dm", ]
"DNS Zone" = [".zone", ".arpa", ]
"Darcs Patch" = [".darcspatch", ".dpatch", ]
Dart = [".dart", ]
Diff = [".diff", ".patch", ]
Dockerfile = [".dockerfile", "Dockerfile", ]
Dogescript = [".djs", ]
Dylan = [".dylan", ".dyl", ".intr", ".lid", ]
E = [".E", ]
ECL = [".ecl", ".eclxml", ]
Eagle = [".sch", ".brd", ]
"Ecere Projects" = [".epj", ]
Eiffel = [".e", ]
Elixir = [".ex", ".exs", ]
Elm = [".elm", ]
"Emacs Lisp" = [".el", ".emacs", ".emacs.desktop", ]
EmberScript = [".em", ".emberscript", ]
Erlang = [".erl", ".escript", ".hrl", ".xrl", ".yrl", ]
"F#" = [".fs", ".fsi", ".fsx", ]
FLUX = [".flux", ]
FORTRAN = [".f90", ".f", ".f03", ".f08", ".f77", ".f95", ".for", ".fpp", ]
Factor = [".factor", ]
Fancy = [".fy", ".fancypack", ]
Fantom = [".fan", ]
Formatted = [".eam.fs", ]
Forth = [".fth", ".4th", ".forth", ".frt", ]
FreeMarker = [".ftl", ]
G-code = [".g", ".gco", ".gcode", ]
GAMS = [".gms", ]
GAP = [".gap", ".gi", ]
GAS = [".s", ]
GDScript = [".gd", ]
GLSL = [".glsl", ".fp", ".frag", ".frg", ".fsh", ".fshader", ".geo", ".geom", ".glslv", ".gshader", ".shader", ".vert", ".vrx", ".vsh", ".vshader", ]
Genshi = [".kid", ]
"Gentoo Ebuild" = [".ebuild", ]
"Gentoo Eclass" = [".eclass", ]
"Gettext Catalog" = [".po", ".pot", ]
Glyph = [".glf", ]
Gnuplot = [".gp", ".gnu", ".gnuplot", ".plot", ".plt", ]
Go = [".go", ]
Golo = [".golo", ]
Gosu = [".gst", ".gsx", ".vark", ]
Grace = [".grace", ]
Gradle = [".gradle", ]
"Grammatical Framework" = [".gf", ]
GraphQL = [".graphql", ]
"Graphviz (DOT)" = [".dot", ".gv", ]
Groff = [".man", ".1", ".1in", ".1m", ".1x", ".2", ".3", ".3in", ".3m", ".3qt", ".3x", ".4", ".5", ".6", ".7", ".8", ".9", ".me", ".rno", ".roff", ]
Groovy = [".groovy", ".grt", ".gtpl", ".gvy", ]
"Groovy Server Pages" = [".gsp", ]
HCL = [".hcl", ".tf", ]
HLSL = [".hlsl", ".fxh", ".hlsli", ]
HTML = [".html", ".htm", ".html.hl", ".xht", ".xhtml", ]
"HTML+Django" = [".mustache", ".jinja", ]
"HTML+EEX" = [".eex", ]
"HTML+ERB" = [".erb", ".erb.deface", ]
"HTML+PHP" = [".phtml", ]
HTTP = [".http", ]
Haml = [".haml", ".haml.deface", ]
Handlebars = [".handlebars", ".hbs", ]
Harbour = [".hb", ]
Haskell = [".hs", ".hsc", ]
Haxe = [".hx", ".hxsl", ]
Hy = [".hy", ]
IDL = [".dlm", ]
"IGOR Pro" = [".ipf", ]
INI = [".ini", ".cfg", ".prefs", ".properties", ]
"IRC log" = [".irclog", ".weechatlog", ]
Idris = [".idr", ".lidr", ]
"Inform 7" = [".ni", ".i7x", ]
"Inno Setup" = [".iss", ]
Io = [".io", ]
Ioke = [".ik", ]
Isabelle = [".thy", ]
J = [".ijs", ]
JFlex = [".flex", ".jflex", ]
JSON = [".json", ".geojson", ".lock", ".topojson", ]
JSON5 = [".json5", ]
JSONLD = [".jsonld", ]
JSONiq = [".jq", ]
JSX = [".jsx", ]
Jade = [".jade", ]
Jasmin = [".j", ]
Java = [".java", ]
"Java Server Pages" = [".jsp", ]
JavaScript = [".js", "._js", ".bones", ".es6", ".jake", ".jsb", ".jscad", ".jsfl", ".jsm", ".jss", ".njs", ".pac", ".sjs", ".ssjs", ".xsjs", ".xsjslib", ]
Julia = [".jl", ]
"Jupyter Notebook" = [".ipynb", ]
KRL = [".krl", ]
KiCad = [".kicad_pcb", ]
Kit = [".kit", ]
Kotlin = [".kt", ".ktm", ".kts", ]
LFE = [".lfe", ]
LLVM = [".ll", ]
LOLCODE = [".lol", ]
LSL = [".lsl", ".lslp", ]
LabVIEW = [".lvproj", ]
Lasso = [".lasso", ".las", ".lasso8", ".lasso9", ".ldml", ]
Latte = [".latte", ]
Lean = [".lean", ".hlean", ]
Less = [".less", ]
Lex = [".lex", ]
LilyPond = [".ly", ".ily", ]
"Linker Script" = [".ld", ".lds", ]
Liquid = [".liquid", ]
"Literate Agda" = [".lagda", ]
"Literate CoffeeScript" = [".litcoffee", ]
"Literate Haskell" = [".lhs", ]
LiveScript = [".ls", "._ls", ]
Logos = [".xm", ".x", ".xi", ]
Logtalk = [".lgt", ".logtalk", ]
LookML = [".lookml", ]
Lua = [".lua", ".nse", ".pd_lua", ".rbxs", ".wlua", ]
M = [".mumps", ]
M4 = [".m4", ]
MAXScript = [".mcr", ]
MTML = [".mtml", ]
MUF = [".muf", ]
Makefile = [".mak", ".mk", ".mkfile", "Makefile", ]
Mako = [".mako", ".mao", ]
Maple = [".mpl", ]
Markdown = [".md", ".markdown", ".mkd", ".mkdn", ".mkdown", ".ron", ]
Mask = [".mask", ]
Mathematica = [".mathematica", ".cdf", ".ma", ".mt", ".nb", ".nbp", ".wl", ".wlt", ]
Matlab = [".matlab", ]
Max = [".maxpat", ".maxhelp", ".maxproj", ".mxt", ".pat", ]
MediaWiki = [".mediawiki", ".wiki", ]
Metal = [".metal", ]
MiniD = [".minid", ]
Mirah = [".druby", ".duby", ".mir", ".mirah", ]
Modelica = [".mo", ]
"Module Management System" = [".mms", ".mmk", ]
Monkey = [".monkey", ]
MoonScript = [".moon", ]
Myghty = [".myt", ]
NSIS = [".nsi", ".nsh", ]
NetLinx = [".axs", ".axi", ]
"NetLinx+ERB" = [".axs.erb", ".axi.erb", ]
NetLogo = [".nlogo", ]
Nginx = [".nginxconf", ]
Nimrod = [".nim", ".nimrod", ]
Ninja = [".ninja", ]
Nit = [".nit", ]
Nix = [".nix", ]
Nu = [".nu", ]
NumPy = [".numpy", ".numpyw", ".numsc", ]
OCaml = [".ml", ".eliom", ".eliomi", ".ml4", ".mli", ".mll", ".mly", ]
ObjDump = [".objdump", ]
"Objective-C++" = [".mm", ]
Objective-J = [".sj", ]
Octave = [".oct", ]
Omgrofl = [".omgrofl", ]
Opa = [".opa", ]
Opal = [".opal", ]
OpenCL = [".cl", ".opencl", ]
"OpenEdge ABL" = [".p", ]
OpenSCAD = [".scad", ]
Org = [".org", ]
Ox = [".ox", ".oxh", ".oxo", ]
Oxygene = [".oxygene", ]
Oz = [".oz", ]
PAWN = [".pwn", ]
PHP = [".php", ".aw", ".ctp", ".php3", ".php4", ".php5", ".phps", ".phpt", ]
"POV-Ray SDL" = [".pov", ]
Pan = [".pan", ]
Papyrus = [".psc", ]
Parrot = [".parrot", ]
"Parrot Assembly" = [".pasm", ]
"Parrot Internal Representation" = [".pir", ]
Pascal = [".pas", ".dfm", ".dpr", ".lpr", ]
Perl = [".pl", ".al", ".perl", ".ph", ".plx", ".pm", ".psgi", ".t", ]
Perl6 = [".6pl", ".6pm", ".nqp", ".p6", ".p6l", ".p6m", ".pl6", ".pm6", ]
Pickle = [".pkl", ]
PigLatin = [".pig", ]
Pike = [".pike", ".pmod", ]
Pod = [".pod", ]
PogoScript = [".pogo", ]
Pony = [".pony", ]
PostScript = [".ps", ".eps", ]
PowerShell = [".ps1", ".psd1", ".psm1", ]
Processing = [".pde", ]
Prolog = [".prolog", ".yap", ]
"Propeller Spin" = [".spin", ]
"Protocol Buffer" = [".proto", ]
"Public Key" = [".pub", ]
"Pure Data" = [".pd", ]
PureBasic = [".pb", ".pbi", ]
PureScript = [".purs", ]
Python = [".py", ".bzl", ".gyp", ".lmi", ".pyde", ".pyp", ".pyt", ".pyw", ".tac", ".wsgi", ".xpy", ]
"Python traceback" = [".pytb", ]
QML = [".qml", ".qbs", ]
QMake = [".pri", ]
R = [".r", ".rd", ".rsx", ]
RAML = [".raml", ]
RDoc = [".rdoc", ]
REALbasic = [".rbbas", ".rbfrm", ".rbmnu", ".rbres", ".rbtbar", ".rbuistate", ]
RHTML = [".rhtml", ]
RMarkdown = [".rmd", ]
Racket = [".rkt", ".rktd", ".rktl", ".scrbl", ]
"Ragel in Ruby Host" = [".rl", ]
"Raw token data" = [".raw", ]
Rebol = [".reb", ".r2", ".r3", ".rebol", ]
Red = [".red", ".reds", ]
Redcode = [".cw", ]
"Ren'Py" = [".rpy", ]
RenderScript = [".rsh", ]
RobotFramework = [".robot", ]
Rouge = [".rg", ]
Ruby = [".rb", ".builder", ".gemspec", ".god", ".irbrc", ".jbuilder", ".mspec", ".podspec", ".rabl", ".rake", ".rbuild", ".rbw", ".rbx", ".ru", ".ruby", ".thor", ".watchr", ]
Rust = [".rs", ".rs.in", ]
SAS = [".sas", ]
SCSS = [".scss", ]
SMT = [".smt2", ".smt", ]
SPARQL = [".sparql", ".rq", ]
SQF = [".sqf", ".hqf", ]
SQL = [".pls", ".pck", ".pkb", ".pks", ".plb", ".plsql", ".sql", ".cql", ".ddl", ".prc", ".tab", ".udf", ".viw", ".db2", ]
STON = [".ston", ]
SVG = [".svg", ]
Sage = [".sage", ".sagews", ]
SaltStack = [".sls", ]
Sass = [".sass", ]
Scala = [".scala", ".sbt", ]
Scaml = [".scaml", ]
Scheme = [".scm", ".sld", ".sps", ".ss", ]
Scilab = [".sci", ".sce", ]
Self = [".self", ]
Shell = [".sh", ".bash", ".bats", ".command", ".ksh", ".sh.in", ".tmux", ".tool", ".zsh", ]
ShellSession = [".sh-session", ]
Shen = [".shen", ]
Slash = [".sl", ]
Slim = [".slim", ]
Smali = [".smali", ]
Smalltalk = [".st", ]
Smarty = [".tpl", ]
Solidity = [".sol", ]
SourcePawn = [".sp", ".sma", ]
Squirrel = [".nut", ]
Stan = [".stan", ]
"Standard ML" = [".ML", ".fun", ".sig", ".sml", ]
Stata = [".do", ".ado", ".doh", ".ihlp", ".mata", ".matah", ".sthlp", ]
Stylus = [".styl", ]
SuperCollider = [".scd", ]
Swift = [".swift", ]
SystemVerilog = [".sv", ".svh", ".vh", ]
TOML = [".toml", ]
TXL = [".txl", ]
Tcl = [".tcl", ".adp", ".tm", ]
Tcsh = [".tcsh", ".csh", ]
TeX = [".tex", ".aux", ".bbx", ".bib", ".cbx", ".dtx", ".ins", ".lbx", ".ltx", ".mkii", ".mkiv", ".mkvi", ".sty", ".toc", ]
Tea = [".tea", ]
Text = [".txt", ".no", ]
Textile = [".textile", ]
Thrift = [".thrift", ]
Turing = [".tu", ]
Turtle = [".ttl", ]
Twig = [".twig", ]
TypeScript = [".ts", ".tsx", ]
"Unified Parallel C" = [".upc", ]
"Unity3D Asset" = [".anim", ".asset", ".mat", ".meta", ".prefab", ".unity", ]
Uno = [".uno", ]
UnrealScript = [".uc", ]
UrWeb = [".ur", ".urs", ]
VCL = [".vcl", ]
VHDL = [".vhdl", ".vhd", ".vhf", ".vhi", ".vho", ".vhs", ".vht", ".vhw", ]
Vala = [".vala", ".vapi", ]
Verilog = [".veo", ]
VimL = [".vim", ]
"Visual Basic" = [".vb", ".bas", ".frm", ".frx", ".vba", ".vbhtml", ".vbs", ]
Volt = [".volt", ]
Vue = [".vue", ]
"Web Ontology Language" = [".owl", ]
WebAssembly = [".wat", ]
WebIDL = [".webidl", ]
X10 = [".x10", ]
XC = [".xc", ]
XML = [".xml", ".ant", ".axml", ".ccxml", ".clixml", ".cproject", ".csl", ".csproj", ".ct", ".dita", ".ditamap", ".ditaval", ".dll.config", ".dotsettings", ".filters", ".fsproj", ".fxml", ".glade", ".grxml", ".iml", ".ivy", ".jelly", ".jsproj", ".kml", ".launch", ".mdpolicy", ".mxml", ".nproj", ".nuspec", ".odd", ".osm", ".plist", ".props", ".ps1xml", ".psc1", ".pt", ".rdf", ".rss", ".scxml", ".srdf", ".storyboard", ".stTheme", ".sublime-snippet", ".targets", ".tmCommand", ".tml", ".tmLanguage", ".tmPreferences", ".tmSnippet", ".tmTheme", ".ui", ".urdf", ".ux", ".vbproj", ".vcxproj", ".vssettings", ".vxml", ".wsdl", ".wsf", ".wxi", ".wxl", ".wxs", ".x3d", ".xacro", ".xaml", ".xib", ".xlf", ".xliff", ".xmi", ".xml.dist", ".xproj", ".xsd", ".xul", ".zcml", ]
XPages = [".xsp-config", ".xsp.metadata", ]
XProc = [".xpl", ".xproc", ]
XQuery = [".xquery", ".xq", ".xql", ".xqm", ".xqy", ]
XS = [".xs", ]
XSLT = [".xslt", ".xsl", ]
Xojo = [".xojo_code", ".xojo_menu", ".xojo_report", ".xojo_script", ".xojo_toolbar", ".xojo_window", ]
Xtend = [".xtend", ]
YAML = [".yml", ".reek", ".rviz", ".sublime-syntax", ".syntax", ".yaml", ".yaml-tmlanguage", ]
YANG = [".yang", ]
Yacc = [".y", ".yacc", ".yy", ]
Zephir = [".zep", ]
Zig = [".zig", ]
Zimpl = [".zimpl", ".zmpl", ".zpl", ]
desktop = [".desktop", ".desktop.in", ]
eC = [".ec", ".eh", ]
edn = [".edn", ]
fish = [".fish", ]
mupad = [".mu", ]
nesC = [".nc", ]
ooc = [".ooc", ]
reStructuredText = [".rst", ".rest", ".rest.txt", ".rst.txt", ]
wisp = [".wisp", ]
xBase = [".prg", ".prw", ]
[docs_blacklist_extensions]
# Disable docs generation for these extensions: text files and scripts that are not programming languages with functions, classes and methods
docs_blacklist = ['sql', 'txt', 'yaml', 'json', 'xml', 'md', 'rst', 'rest', 'rest.txt', 'rst.txt', 'mdpolicy', 'mdown', 'markdown', 'mdwn', 'mkd', 'mkdn', 'mkdown', 'sh']

View File

@ -0,0 +1,126 @@
[pr_add_docs_prompt]
system=""" PR-Doc, Pull Request (PR) .
PR Diff {{ docs_for_language }}.
PR Diff :
======
## file: 'src/file1.py'
@@ -12,3 +12,4 @@ def func1():
__new hunk__
12 1 PR
14 +PR 1
15 +PR 2
16 2 PR
__old hunk__
1 PR
-PR
2 PR
@@ ... @@ def func2():
__new hunk__
...
__old hunk__
...
## file: 'src/file2.py'
...
======
:
- / (//...), {{ docs_for_language }}.
- PR ( {{ language }} ) , {{ docs_for_language }}.
- '__new hunk__' . , header body.
- {{ docs_for_language }} {{ language }} {{ docs_for_language }} .
- {{ docs_for_language }} .
- {{ docs_for_language }} ().
{%- if extra_instructions %}
:
======
{{ extra_instructions }}
======
{%- endif %}
使 YAML :
```yaml
Code Documentation:
type: array
uniqueItems: true
items:
relevant file:
type: string
description: The full file path of the relevant file.
relevant line:
type: integer
description: |-
The relevant line number from a '__new hunk__' section where the {{ docs_for_language }} should be added.
doc placement:
type: string
enum:
- before
- after
description: |-
The {{ docs_for_language }} placement relative to the relevant line (code component).
For example, in Python the docs are placed after the function signature, but in Java they are placed before.
documentation:
type: string
description: |-
The {{ docs_for_language }} content. It should be complete, correctly formatted and indented, and without line numbers.
```
:
```yaml
Code Documentation:
- relevant file: |-
src/file1.py
relevant lines: 12
doc placement: after
documentation: |-
\"\"\"
This is a python docstring for func1.
\"\"\"
- ...
...
```
YAML , , ('|-').
, 'type' 'description' .
"""
user="""PR Info:
Title: '{{ title }}'
Branch: '{{ branch }}'
{%- if description %}
Description:
======
{{ description|trim }}
======
{%- endif %}
{%- if language %}
Main PR language: '{{language}}'
{%- endif %}
The PR Diff:
======
{{ diff|trim }}
======
Response (should be a valid YAML, and nothing else):
```yaml
"""

View File

@ -0,0 +1,166 @@
[pr_code_suggestions_prompt]
system=""" PR-Reviewer, Pull Request (PR) AI.
{%- if not focus_only_on_problems %}
diff, ( '+' ), bug , .
{%- else %}
diff, ( '+' ), bug .
{%- endif %}
PR :
======
## File: 'src/file1.py'
{%- if is_ai_metadata %}
### AI-generated changes summary:
* ...
* ...
{%- endif %}
@@ ... @@ def func1():
__new hunk__
unchanged code line0
unchanged code line1
+new code line2 added
unchanged code line3
__old hunk__
unchanged code line0
unchanged code line1
-old code line2 removed
unchanged code line3
@@ ... @@ def func2():
__new hunk__
unchanged code line4
+new code line5 added
unchanged code line6
## File: 'src/file2.py'
...
======
:
1. PR '__new hunk__' '__old hunk__' :
- '__new hunk__' PR .
- '__old hunk__' PR . , '__old hunk__' .
2. 使:
'+' ( '__new hunk__' )
'-' ( '__old hunk__' )
' ' ()
{%- if is_ai_metadata %}
3. , AI , . , .
{%- endif %}
{%- if not focus_only_on_problems %}
- {{num_code_suggestions}}.
{%- else %}
- {{num_code_suggestions}}. ,.
{%- endif %}
- '-','+'.
- PR( '__new hunk__' '+').
{%- if not focus_only_on_problems %}
- PR,. PR. ,.
- ,,使,使.
{%- else %}
- PR. ,.
- ,.
{%- endif %}
- (,),(``). :"验证`user_id`是否...".
- ,(PRdiff hunks),. ,().
{%- if extra_instructions %}
():
======
{{ extra_instructions }}
使,使!
======
{%- endif %}
YAML $PRCodeSuggestions, Pydantic:
=====
class CodeSuggestion(BaseModel):
relevant_file: str = Field(description="相关文件的完整路径")
language: str = Field(description="相关文件使用的编程语言")
suggestion_content: str = Field(description="一个可操作的建议,用于增强、改进或修复PR中引入的新代码. 不要在这里呈现实际的代码片段, 只需要建议. 简明扼要")
existing_code: str = Field(description="一个简短的代码片段, 来自PR更改后的 '__new hunk__' 部分, 该建议旨在增强或修复. 仅包括完整的代码行. 如果需要, 使用省略号 (...) 来保持简洁. 此片段应代表目标改进的特定PR代码.")
improved_code: str = Field(description="一个改进的代码片段, 在实施建议后替换 'existing_code' 片段.")
one_sentence_summary: str = Field(description="对建议的改进进行简明扼要的单句概述 (最多6个词). 关注 'what'. 保持通用性, 避免方法或变量名称,回答尽量使用简体中文,并且必须使用英文标点符号.")
{%- if not focus_only_on_problems %}
label: str = Field(description="一个单一的、描述性的标签, 最能描述建议类型. 可能的标签包括 '安全', '可能的错误', '可能的问题', '性能', '增强', '最佳实践', '可维护性', '拼写错误'. 其他相关标签也可以接受.")
{%- else %}
label: str = Field(description="一个单一的、描述性的标签, 最能描述建议类型. 可能的标签包括 '安全', '关键漏洞', '一般'. '一般' 部分应用于解决主要问题, 但不一定是关键级别的建议.")
{%- endif %}
class PRCodeSuggestions(BaseModel):
code_suggestions: List[CodeSuggestion]
=====
:
```yaml
code_suggestions:
- relevant_file: |
src/file1.py
language: |
python
suggestion_content: |
...
existing_code: |
...
improved_code: |
...
one_sentence_summary: |
...
label: |
...
```
YAML, , ('|').
"""
user="""--PR --
: '{{title}}'
{%- if date %}
: {{date}}
{%- endif %}
PR :
======
{{ diff_no_line_numbers|trim }}
======
{%- if duplicate_prompt_examples %}
:
```yaml
code_suggestions:
- relevant_file: |
src/file1.py
language: |
python
suggestion_content: |
...
existing_code: |
...
improved_code: |
...
one_sentence_summary: |
...
label: |
...
```
( '...' )
{%- endif %}
(YAML,):
```yaml
"""

View File

@ -0,0 +1,146 @@
[pr_code_suggestions_reflect_prompt]
system="""AI,Pull Request (PR).
PR,AI.,PR.
,PR,.,.PR.
:
1. 'one_sentence_summary' -
2. 'suggestion_content' - ,
3. 'existing_code' - PR__new hunk____,
4. 'improved_code' - ,'existing_code'
:
- PR
- 'improved_code','existing_code'
- PR
,0.
PR,.
,'existing_code''__new hunk__'.
:
- PR.,,,PR.
- ,,.
- 'existing_code'PR'__new hunk__'.
- 'improved_code','existing_code'.
- :
- ()(8-10).
- ,,(3-7).
- .
- ,.
:
- ,,1-2.
- 0:
- ,
- 使
- 使.
PR:
======
## File: 'src/file1.py'
{%- if is_ai_metadata %}
### AI-generated changes summary:
* ...
* ...
{%- endif %}
@@ ... @@ def func1():
__new hunk__
11 unchanged code line0
12 unchanged code line1
13 +new code line2 added
14 unchanged code line3
__old hunk__
unchanged code line0
unchanged code line1
-old code line2 removed
unchanged code line3
@@ ... @@ def func2():
__new hunk__
...
__old hunk__
...
## File: 'src/file2.py'
...
======
- ,'__new hunk__''__old hunk__'.'__new hunk__','__old hunk__'.,.
- '__new hunk__',便.,.
- : '+'PR, '-', ' '.
{%- if is_ai_metadata %}
- ,AI,.,.
{%- endif %}
$PRCodeSuggestionsFeedbackYAML,Pydantic:
=====
class CodeSuggestionFeedback(BaseModel):
suggestion_summary: str = Field(description="从输入重复")
relevant_file: str = Field(description="从输入重复")
relevant_lines_start: int = Field(description="相关的行号,来自'__new hunk__'部分,建议开始的位置(包括在内).应该从hunk行号中导出,并对应于相关的'现有代码'片段的开头")
relevant_lines_end: int = Field(description="相关的行号,来自'__new hunk__'部分,建议结束的位置(包括在内).应该从hunk行号中导出,并对应于相关的'现有代码'片段的结尾")
suggestion_score: int = Field(description="评估建议并分配一个从0到10的分数.如果建议是错误的,则给0.对于有效的建议,从1(最低影响/重要性)到10(最高影响/重要性)评分.")
why: str = Field(description="用1-2句话简要解释给出的分数,重点关注建议的影响,相关性和准确性.")
class PRCodeSuggestionsFeedback(BaseModel):
code_suggestions: List[CodeSuggestionFeedback]
=====
Example output:
```yaml
code_suggestions:
- suggestion_summary: |
使
relevant_file: "src/file1.py"
relevant_lines_start: 13
relevant_lines_end: 14
suggestion_score: 6
why: |
t,使.
- ...
```
YAML MUST ,, ('|').
"""
user="""Pull Request (PR):
======
{{ diff|trim }}
======
{{ num_code_suggestions }} AI, Pull Request:
======
{{ suggestion_str|trim }}
======
{%- if duplicate_prompt_examples %}
Example output:
```yaml
code_suggestions:
- suggestion_summary: |
...
relevant_file: "..."
relevant_lines_start: ...
relevant_lines_end: ...
suggestion_score: ...
why: |
...
- ...
```
('...')
{%- endif %}
(YAML,):
```yaml
"""

View File

@ -0,0 +1,86 @@
[pr_custom_labels_prompt]
system="""PR-Reviewer, Git Pull Request (PR).
PR.
{%- if enable_custom_labels %}
, PR.
{%- endif %}
{%- if extra_instructions %}
:
======
{{ extra_instructions }}
======
{% endif %}
$Labels YAML , Pydantic :
======
{%- if enable_custom_labels %}
{{ custom_labels_class }}
{%- else %}
class Label(str, Enum):
bug_fix = "Bug 修复"
tests = "测试"
enhancement = "增强"
documentation = "文档"
other = "其他"
{%- endif %}
class Labels(BaseModel):
labels: List[Label] = Field(min_items=0, description="选择描述PR内容的相关自定义标签, 并返回它们的键. 使用 Label 对象的值来更好地理解标签含义.")
======
:
```yaml
labels:
- ...
- ...
```
YAML,.
"""
user="""PR :
: '{{title}}'
: '{{ branch }}'
{%- if description %}
:
======
{{ description|trim }}
======
{%- endif %}
{%- if language %}
PR : '{{ language }}'
{%- endif %}
{%- if commit_messages_str %}
:
======
{{ commit_messages_str|trim }}
======
{%- endif %}
PR Git :
======
{{ diff|trim }}
======
, , : '-' , '+' , ' ' () .
(YAML, ):
```yaml
"""

View File

@ -0,0 +1,167 @@
[pr_description_prompt]
system="""PR-Reviewer, Git Pull Request (PR)
PR - , ,
- PR ('PR Git Diff''+')
- , 'Previous title', 'Previous description''Commit messages', , . , PR diff, .
- .
- , YAML使 ('|')
- , , 使 (`) (').
{%- if extra_instructions %}
:
=====
{{extra_instructions}}
=====
{% endif %}
$PRDescription YAML , Pydantic :
=====
class PRType(str, Enum):
bug_fix = "Bug 修复"
tests = "测试"
enhancement = "增强"
documentation = "文档"
other = "其他"
{%- if enable_custom_labels %}
{{ custom_labels_class }}
{%- endif %}
{%- if enable_semantic_files_types %}
class FileDescription(BaseModel):
filename: str = Field(description="相关文件的完整文件路径")
{%- if include_file_summary_changes %}
changes_summary: str = Field(description="相关文件中更改的简洁摘要, 以项目符号列出 (1-4 个项目符号)")
{%- endif %}
changes_title: str = Field(description="一行摘要 (5-10 个字) 概括文件中更改的主题")
label: str = Field(description="代表文件中发生的代码更改类型的单个语义标签, 可能的值 (部分列表): 'Bug 修复', '测试', '增强', '文档', '错误处理', '配置更改', '依赖', '格式化', '杂项', ...")
{%- endif %}
class PRDescription(BaseModel):
type: List[PRType] = Field(description="描述 PR 内容的一种或多种类型, 返回 label 成员值 (例如 'Bug 修复', 而不是 'bug_修复')")
description: str = Field(description="最多用四个项目符号概括 PR 更改, 每个项目符号最多 8 个字, 对于大型 PR, 如果需要, 添加子项目符号, 按重要性对项目符号排序, 每个项目符号突出显示一个关键更改组")
title: str = Field(description="一个简洁且描述性的标题, 概括了 PR 的主要主题")
{%- if enable_semantic_files_types %}
pr_files: List[FileDescription] = Field(max_items=20, description="PR 中更改的所有文件的列表, 以及其更改的摘要, 必须分析每个文件, 无论更改大小")
{%- endif %}
=====
:
```yaml
type:
- ...
- ...
description: |
...
title: |
...
{%- if enable_semantic_files_types %}
pr_files:
- filename: |
...
{%- if include_file_summary_changes %}
changes_summary: |
...
{%- endif %}
changes_title: |
...
label: |
label_key_1
...
{%- endif %}
```
YAML, . YAML, ('|')
"""
user="""
{%- if related_tickets %}
:
{% for ticket in related_tickets %}
=====
: '{{ ticket.title }}'
{%- if ticket.labels %}
: {{ ticket.labels }}
{%- endif %}
{%- if ticket.body %}
:
#####
{{ ticket.body }}
#####
{%- endif %}
=====
{% endfor %}
{%- endif %}
PR :
: '{{title}}'
{%- if description %}
:
=====
{{ description|trim }}
=====
{%- endif %}
: '{{branch}}'
{%- if commit_messages_str %}
:
=====
{{ commit_messages_str|trim }}
=====
{%- endif %}
The PR Git Diff:
=====
{{ diff|trim }}
=====
, diff : '-' , '+' , ' ' () .
{%- if duplicate_prompt_examples %}
:
```yaml
type:
- Bug 修复
- 增强
- ...
description: |
...
title: |
...
{%- if enable_semantic_files_types %}
pr_files:
- filename: |
...
{%- if include_file_summary_changes %}
changes_summary: |
...
{%- endif %}
changes_title: |
...
label: |
label_key_1
...
{%- endif %}
```
( '...' )
{%- endif %}
(YAML, ):
```yaml
"""

View File

@ -0,0 +1,68 @@
[pr_evaluate_prompt]
prompt="""\
PR,,(PR).
:
***** *****
{{pr_task|trim}}
***** *****
1 :
***** 1 *****
{{pr_response1|trim}}
***** 1 *****
2 :
***** 2 *****
{{pr_response2|trim}}
***** 2 *****
:
- .,PR.
- 12.,,.
,.:
- ?
- PR?
- ,?
- ,?
- .,,.
YAML,$PRRankRespones,Pydantic:
=====
class PRRankRespones(BaseModel):
which_response_was_better: Literal[0, 1, 2] = Field(description="一个数字,指示哪个响应更好.0表示两个响应同样好.")
why: str = Field(description="以简短明了的方式,解释为什么选择的响应比另一个更好.如果相关,请具体说明并举例.")
score_response1: int = Field(description="一个介于1到10之间的分数,根据提示中提到的标准,指示response1的质量.")
score_response2: int = Field(description="一个介于1到10之间的分数,根据提示中提到的标准,指示response2的质量.")
=====
:
```yaml
which_response_was_better: "X"
why: "响应 X 更好,因为它更实用,并且更好地满足任务要求,因为 ..."
score_response1: ...
score_response2: ...
```
(YAML,):
```yaml
"""

View File

@ -0,0 +1,53 @@
[pr_help_prompts]
system="""你是一名Doc-helper, 一个被设计用来回答关于名为"PR-Agent"(最近重命名为"Qodo Merge").
, .
使.
:
- . , .
- PR-Agent'describe', 'review', 'improve'. , .
- , , , , , .
YAML, $DocHelper, Pydantic:
=====
class relevant_section(BaseModel):
file_name: str = Field(description="相关文件的名称")
relevant_section_header_string: str = Field(description="来自相关文件的相关markdown章节标题的确切文本 (以'#', '##'等开头). 如果整个文件是相关章节, 或者相关章节没有标题, 则返回空字符串")
class DocHelper(BaseModel):
user_question: str = Field(description="用户的问题")
response: str = Field(description="对用户问题的回复")
relevant_sections: List[relevant_section] = Field(description="文档中回答用户问题的相关markdown章节列表, 按相关性排序 (最相关的在前)")
=====
:
```yaml
user_question: |
...
response: |
...
relevant_sections:
- file_name: "src/file1.py"
relevant_section_header_string: |
...
- ...
```
"""
user="""\
:
=====
{{ question|trim }}
=====
:
=====
{{ snippets|trim }}
=====
(YAML, ):
```yaml
"""

View File

@ -0,0 +1,53 @@
[pr_information_from_user_prompt]
system="""PR-Reviewer, Git Pull Request(PR).
PRPR Git Diff, PR3PR.
PR, , , , PR.
\\, . , , .
:
'
PR:
1) ...
2) ...
...
'
"""
user="""PR :
: '{{title}}'
: '{{branch}}'
{%- if description %}
:
======
{{ description|trim }}
======
{%- endif %}
{%- if language %}
PR : '{{ language }}'
{%- endif %}
{%- if commit_messages_str %}
:
======
{{ commit_messages_str|trim }}
======
{%- endif %}
PR Git :
======
{{ diff|trim }}
======
diff, : '-', '+', ' '()
:
"""

View File

@ -0,0 +1,53 @@
[pr_line_questions_prompt]
system="""PR, Git Pull Request (PR) .
PR/, .
, , . .
. , .
:
- , 使 (`) (').
- , 使.
- .
:
======
## File: 'src/file1.py'
@@ -12,5 +12,5 @@ def func1():
code line 1 that remained unchanged in the PR
code line 2 that remained unchanged in the PR
-code line that was removed in the PR
+code line added in the PR
code line 3 that remained unchanged in the PR
======
"""
user="""PR :
: '{{title}}'
: '{{branch}}'
PR:
======
{{ full_hunk|trim }}
======
:
======
{{ selected_lines|trim }}
======
, , : '-' , '+' , ' ' ()
:
======
{{ question|trim }}
======
:
"""

View File

@ -0,0 +1,44 @@
[pr_questions_prompt]
system="""PR,Git Pull Request(PR).
PR\\('PR Git Diff''+'),.
,,.
.
.
,.
"""
user="""PR :
: '{{title}}'
: '{{branch}}'
{%- if description %}
:
======
{{ description|trim }}
======
{%- endif %}
{%- if language %}
PR : '{{ language }}'
{%- endif %}
The PR Git Diff:
======
{{ diff|trim }}
======
diff,: '-', '+', ' ' ()
PR :
======
{{ questions|trim }}
======
PR :
"""

View File

@ -0,0 +1,283 @@
[pr_review_prompt]
system="""PR-Reviewer, Git Pull Request (PR) .
PR .
PR ( '+' )
PR :
======
## File: 'src/file1.py'
{%- if is_ai_metadata %}
### AI-生成的更改摘要:
* ...
* ...
{%- endif %}
@@ ... @@ def func1():
__new hunk__
11 unchanged code line0
12 unchanged code line1
13 +new code line2 added
14 unchanged code line3
__old hunk__
unchanged code line0
unchanged code line1
-old code line2 removed
unchanged code line3
@@ ... @@ def func2():
__new hunk__
unchanged code line4
+new code line5 added
unchanged code line6
## File: 'src/file2.py'
...
======
- , '__new hunk__' '__old hunk__' , .'__new hunk__' , '__old hunk__' ., __old hunk__ .
- '__new hunk__' , ., .
- ('+', '-', ' ') .'+' PR , '-' PR , ' ' . \
PR ( '+' )
{%- if is_ai_metadata %}
- , AI ., .
{%- endif %}
- , , 使 (`) (').
{%- if extra_instructions %}
:
======
{{ extra_instructions }}
======
{% endif %}
YAML , $PRReview , Pydantic :
=====
{%- if require_can_be_split_review %}
class SubPR(BaseModel):
relevant_files: List[str] = Field(description="子 PR 的相关文件")
title: str = Field(description="独立且有意义的子 PR 的简短标题, 仅由相关文件组成")
{%- endif %}
class KeyIssuesComponentLink(BaseModel):
relevant_file: str = Field(description="相关文件的完整文件路径")
issue_header: str = Field(description="问题的标题, 一到两个词.例如: 'Possible Bug' 等")
issue_content: str = Field(description="关于应该在 PR 审查过程中进一步检查和验证的内容的简短而简洁的摘要.不要在此字段中引用行号.")
start_line: int = Field(description="相关文件中与此问题对应的起始行")
end_line: int = Field(description="相关文件中与此问题对应的结束行")
{%- if related_tickets %}
class TicketCompliance(BaseModel):
ticket_url: str = Field(description="工单 URL 或 ID")
ticket_requirements: str = Field(description="用你自己的话 (以项目符号) 重复工单提出的所有要求, 子任务, DoD 和验收标准")
fully_compliant_requirements: str = Field(description="上面 'ticket_requirements' 部分的项目列表中, PR 代码满足的项目.不要解释如何满足要求, 只简短地列出它们即可.可以为空")
not_compliant_requirements: str = Field(description="上面 'ticket_requirements' 部分的项目列表中, PR 代码未满足的项目.不要解释如何不满足要求, 只简短地列出它们即可.可以为空")
requires_further_human_verification: str = Field(description="上面 'ticket_requirements' 部分的项目列表中, 无法仅通过代码审查进行评估, 不明确或需要进一步人工审查 (例如, 浏览器测试, UI 检查) 的项目.如果所有 'ticket_requirements' 都被标记为完全符合或不符合, 则留空")
{%- endif %}
class Review(BaseModel):
{%- if related_tickets %}
ticket_compliance_check: List[TicketCompliance] = Field(description="相关工单的合规性检查列表")
{%- endif %}
{%- if require_estimate_effort_to_review %}
estimated_effort_to_review_[1-5]: int = Field(description="在 1-5 (包括 1 和 5) 的范围内估计经验丰富且知识渊博的开发人员审查此 PR 所需的时间和精力.1 表示简短且容易审查, 5 表示漫长且困难的审查.考虑到 PR 代码差异的大小, 复杂性, 质量和所需的更改.")
{%- endif %}
{%- if require_score %}
score: str = Field(description="在 0-100 (包括 0 和 100) 的范围内对此 PR 进行评分, 其中 0 表示最差的 PR 代码, 而 100 表示最高质量的 PR 代码, 没有任何错误或性能问题, 可以立即合并并在生产环境中大规模运行.")
{%- endif %}
{%- if require_tests %}
relevant_tests: str = Field(description="是\\否 问题: 此 PR 是否添加或更新了相关测试?")
{%- endif %}
{%- if question_str %}
insights_from_user_answers: str = Field(description="简要总结你从用户对问题的回答中获得的见解")
{%- endif %}
key_issues_to_review: List[KeyIssuesComponentLink] = Field(description="PR 代码中引入的需要 PR 审查员进一步关注和验证的高优先级错误, 问题或性能问题的简短且多样的列表 (0-3 个问题),内容尽量使用简体中文和英文标点符号.")
{%- if require_security_review %}
security_concerns: str = Field(description="此 PR 代码是否引入了可能的漏洞, 例如敏感信息 (例如, API 密钥, 秘密, 密码) 的暴露, 或安全问题, 如 SQL 注入, XSS, CSRF 和其他 ? 如果没有可能的问题, 回答 'No' (不解释原因).如果存在安全隐患或问题, 请以简短的标题开头回答, 例如: '敏感信息泄露: ...', 'SQL 注入: ...' 等.解释你的答案.如果可能, 请具体说明并举例说明,内容尽量使用简体中文和英文标点符号.")
{%- endif %}
{%- if require_can_be_split_review %}
can_be_split: List[SubPR] = Field(min_items=0, max_items=3, description="这个 PR 总共包含 {{ num_pr_files }} 个更改文件, 是否可以将其划分为更小的子 PR, 这些子 PR 具有可以独立审查和合并的不同任务, 而不考虑顺序 ? 确保子 PR 确实是独立的, 彼此之间没有代码依赖关系, 并且每个子 PR 都代表一个有意义的独立任务.如果 PR 代码不需要拆分, 则输出一个空列表.")
{%- endif %}
class PRReview(BaseModel):
review: Review
=====
:
```yaml
review:
{%- if related_tickets %}
ticket_compliance_check:
- ticket_url: |
...
ticket_requirements: |
...
fully_compliant_requirements: |
...
not_compliant_requirements: |
...
requires_further_human_verification: |
...
{%- endif %}
{%- if require_estimate_effort_to_review %}
estimated_effort_to_review_[1-5]: |
3
{%- endif %}
{%- if require_score %}
score: 89
{%- endif %}
relevant_tests: |
No
key_issues_to_review:
- relevant_file: |
directory/xxx.py
issue_header: |
Bug
issue_content: |
...
start_line: 12
end_line: 14
- ...
security_concerns: |
No
{%- if require_can_be_split_review %}
can_be_split:
- relevant_files:
- ...
- ...
title: ...
- ...
{%- endif %}
```
YAML, . YAML , ('|')
"""
user="""
{%- if related_tickets %}
--PR Ticket Info--
{%- for ticket in related_tickets %}
=====
Ticket URL: '{{ ticket.ticket_url }}'
Ticket Title: '{{ ticket.title }}'
{%- if ticket.labels %}
Ticket Labels: {{ ticket.labels }}
{%- endif %}
{%- if ticket.body %}
Ticket Description:
#####
{{ ticket.body }}
#####
{%- endif %}
=====
{% endfor %}
{%- endif %}
--PR --
{%- if date %}
: {{date}}
{%- endif %}
: '{{title}}'
: '{{branch}}'
{%- if description %}
PR :
======
{{ description|trim }}
======
{%- endif %}
{%- if question_str %}
=====
PR ...
{{ question_str|trim }}
:
'
{{ answer_str|trim }}
'
=====
{%- endif %}
PR :
======
{{ diff|trim }}
======
{%- if duplicate_prompt_examples %}
:
```yaml
review:
{%- if related_tickets %}
ticket_compliance_check:
- ticket_url: |
...
ticket_requirements: |
...
fully_compliant_requirements: |
...
not_compliant_requirements: |
...
requires_further_human_verification: |
...
{%- endif %}
{%- if require_estimate_effort_to_review %}
estimated_effort_to_review_[1-5]: |
3
{%- endif %}
{%- if require_score %}
score: 89
{%- endif %}
relevant_tests: |
No
key_issues_to_review:
- relevant_file: |
...
issue_header: |
...
issue_content: |
...
start_line: ...
end_line: ...
- ...
security_concerns: |
No
{%- if require_can_be_split_review %}
can_be_split:
- relevant_files:
- ...
- ...
title: ...
- ...
{%- endif %}
```
( '...' )
{%- endif %}
(YAML, ):
```yaml
"""

View File

@ -0,0 +1,46 @@
[pr_sort_code_suggestions_prompt]
system="""
"""
user=""", Git Pull Request (PR):
======
{{ suggestion_str|trim }}
======
,.
,.
PR ,,.
使 YAML :
```yaml
Sort Order:
type: array
maxItems: {{ suggestion_list|length }}
uniqueItems: true
items:
suggestion number:
type: integer
minimum: 1
maximum: {{ suggestion_list|length }}
importance order:
type: integer
minimum: 1
maximum: {{ suggestion_list|length }}
```
:
```yaml
Sort Order:
- suggestion number: 1
importance order: 2
- suggestion number: 2
importance order: 3
- suggestion number: 3
importance order: 1
```
YAML.,使 ('|').
, 'type' 'description' .
( YAML,):
```yaml
"""

View File

@ -0,0 +1,70 @@
[pr_update_changelog_prompt]
system="""PR-Changelog-Updater
CHANGELOG.mdPR:
- ,,
- ()
- ,, ,3-4
- CHANGELOG.md,
{%- if pr_link %}
- , 使PR URL '{{ pr_link }}' : [*][pr_link]
{%- endif %}
{%- if extra_instructions %}
:
======
{{ extra_instructions|trim }}
======
{%- endif %}
"""
user="""PR :
: '{{title}}'
: '{{branch}}'
{%- if description %}
:
======
{{ description|trim }}
======
{%- endif %}
{%- if language %}
PR: '{{ language }}'
{%- endif %}
{%- if commit_messages_str %}
:
======
{{ commit_messages_str|trim }}
======
{%- endif %}
PR Git :
======
{{ diff|trim }}
======
:
```
{{today}}
```
'CHANGELOG.md'
======
{{ changelog_file_str }}
======
:
```markdown
"""

View File

View File

@ -0,0 +1,180 @@
import copy
import textwrap
from functools import partial
from typing import Dict
from jinja2 import Environment, StrictUndefined
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from utils.pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.utils import load_yaml
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import get_git_provider
from utils.pr_agent.git_providers.git_provider import get_main_pr_language
from utils.pr_agent.log import get_logger
class PRAddDocs:
    """Ask the AI model for documentation of the PR's code and publish it inline.

    The ``pr_add_docs_prompt`` templates are rendered with the PR metadata and
    diff, the answer is parsed as YAML, and every parsed entry becomes a code
    suggestion on the referenced file and line.
    """

    def __init__(self, pr_url: str, cli_mode=False, args: list = None,
                 ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
        """Collect PR metadata and prepare the prompt variables.

        Args:
            pr_url: URL of the pull request, handed to the git provider.
            cli_mode: Stored on the instance as ``self.cli_mode``.
            args: Accepted for signature compatibility; not read here.
            ai_handler: Zero-argument callable returning the AI handler to use.
        """
        self.git_provider = get_git_provider()(pr_url)
        self.main_language = get_main_pr_language(
            self.git_provider.get_languages(), self.git_provider.get_files()
        )

        self.ai_handler = ai_handler()
        self.ai_handler.main_pr_language = self.main_language

        self.patches_diff = None
        self.prediction = None
        self.cli_mode = cli_mode
        self.vars = {
            "title": self.git_provider.pr.title,
            "branch": self.git_provider.get_pr_branch(),
            "description": self.git_provider.get_pr_description(),
            "language": self.main_language,
            "diff": "",  # empty diff for initial calculation
            "extra_instructions": get_settings().pr_add_docs.extra_instructions,
            "commit_messages_str": self.git_provider.get_commit_messages(),
            'docs_for_language': get_docs_for_language(self.main_language,
                                                       get_settings().pr_add_docs.docs_style),
        }
        self.token_handler = TokenHandler(self.git_provider.pr,
                                          self.vars,
                                          get_settings().pr_add_docs_prompt.system,
                                          get_settings().pr_add_docs_prompt.user)

    async def run(self):
        """Generate the documentation and, when publishing is enabled, push it to the PR.

        Any exception is logged and swallowed, so this coroutine never raises.
        """
        try:
            get_logger().info('Generating code Docs for PR...')
            if get_settings().config.publish_output:
                self.git_provider.publish_comment("生成文档中...", is_temporary=True)

            get_logger().info('Preparing PR documentation...')
            await retry_with_fallback_models(self._prepare_prediction)
            data = self._prepare_pr_code_docs()
            if (not data) or ('Code Documentation' not in data):
                get_logger().info('No code documentation found for PR.')
                return

            if get_settings().config.publish_output:
                get_logger().info('Pushing PR documentation...')
                self.git_provider.remove_initial_comment()
                get_logger().info('Pushing inline code documentation...')
                self.push_inline_docs(data)
        except Exception as e:
            get_logger().error(f"Failed to generate code documentation for PR, error: {e}")

    async def _prepare_prediction(self, model: str):
        """Fetch the PR diff for ``model`` and store the model's answer in ``self.prediction``."""
        get_logger().info('Getting PR diff...')
        self.patches_diff = get_pr_diff(self.git_provider,
                                        self.token_handler,
                                        model,
                                        add_line_numbers_to_hunks=True,
                                        disable_extra_lines=False)

        get_logger().info('Getting AI prediction...')
        self.prediction = await self._get_prediction(model)

    async def _get_prediction(self, model: str):
        """Render both prompts with the current diff and return the model's raw response."""
        variables = copy.deepcopy(self.vars)
        variables["diff"] = self.patches_diff  # update diff
        # StrictUndefined makes a template variable missing from `variables` an error.
        environment = Environment(undefined=StrictUndefined)
        system_prompt = environment.from_string(get_settings().pr_add_docs_prompt.system).render(variables)
        user_prompt = environment.from_string(get_settings().pr_add_docs_prompt.user).render(variables)
        if get_settings().config.verbosity_level >= 2:
            get_logger().info(f"\nSystem prompt:\n{system_prompt}")
            get_logger().info(f"\nUser prompt:\n{user_prompt}")
        response, _finish_reason = await self.ai_handler.chat_completion(
            model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)

        return response

    def _prepare_pr_code_docs(self) -> Dict:
        """Parse ``self.prediction`` as YAML; a bare list is wrapped under 'Code Documentation'."""
        docs = self.prediction.strip()
        data = load_yaml(docs)
        if isinstance(data, list):
            data = {'Code Documentation': data}
        return data

    def push_inline_docs(self, data):
        """Turn every parsed documentation entry into a code suggestion and publish them.

        Entries that cannot be parsed are skipped. If publishing the whole batch
        fails, each suggestion is published on its own.
        """
        docs = []

        if not data['Code Documentation']:
            return self.git_provider.publish_comment('No code documentation found to improve this PR.')

        for d in data['Code Documentation']:
            try:
                if get_settings().config.verbosity_level >= 2:
                    get_logger().info(f"add_docs: {d}")
                relevant_file = d['relevant file'].strip()
                relevant_line = int(d['relevant line'])  # absolute position
                documentation = d['documentation']
                doc_placement = d['doc placement'].strip()
                if documentation:
                    new_code_snippet = self.dedent_code(relevant_file, relevant_line, documentation, doc_placement,
                                                        add_original_line=True)

                    body = "**Suggestion:** Proposed documentation\n```suggestion\n" + new_code_snippet + "\n```"
                    docs.append({'body': body, 'relevant_file': relevant_file,
                                 'relevant_lines_start': relevant_line,
                                 'relevant_lines_end': relevant_line})
            except Exception:
                # Best effort: one malformed entry must not block the others.
                if get_settings().config.verbosity_level >= 2:
                    get_logger().info(f"Could not parse code docs: {d}")

        if not docs:
            # Nothing usable was parsed; do not send an empty batch to the provider.
            get_logger().info('No code documentation found for PR.')
            return None

        is_successful = self.git_provider.publish_code_suggestions(docs)
        if not is_successful:
            get_logger().info("Failed to publish code docs, trying to publish each docs separately")
            for doc_suggestion in docs:
                self.git_provider.publish_code_suggestions([doc_suggestion])

    def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet, doc_placement='after',
                    add_original_line=False):
        """Align ``new_code_snippet`` with the file's indentation and attach the original line.

        Args:
            relevant_file: File name as reported by the diff files.
            relevant_lines_start: 1-based line number in the head version of the file.
            new_code_snippet: Documentation text proposed by the model.
            doc_placement: 'after' puts the docs below the original line and
                indents them like the following line; any other value puts
                them above it and indents them like the original line.
            add_original_line: When true, the original line is included in the result.

        Returns:
            The adjusted snippet, or the snippet unchanged when the file or line
            cannot be resolved or any error occurs.
        """
        try:  # dedent code snippet
            self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files \
                else self.git_provider.get_diff_files()
            original_initial_line = None
            head_lines = []
            for file in self.diff_files:
                if file.filename.strip() == relevant_file:
                    head_lines = file.head_file.splitlines()
                    # Reject line numbers outside the file; 0 would otherwise wrap
                    # around to the last line through index -1.
                    if 1 <= relevant_lines_start <= len(head_lines):
                        original_initial_line = head_lines[relevant_lines_start - 1]
                    break
            if original_initial_line:
                if doc_placement == 'after' and relevant_lines_start < len(head_lines):
                    line = head_lines[relevant_lines_start]
                else:
                    # Also used for 'after' on the file's last line, which has no
                    # following line to copy the indentation from.
                    line = original_initial_line
                suggested_initial_line = new_code_snippet.splitlines()[0]
                original_initial_spaces = len(line) - len(line.lstrip())
                suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip())
                delta_spaces = original_initial_spaces - suggested_initial_spaces
                if delta_spaces > 0:
                    new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
                if add_original_line:
                    if doc_placement == 'after':
                        new_code_snippet = original_initial_line + "\n" + new_code_snippet
                    else:
                        new_code_snippet = new_code_snippet.rstrip() + "\n" + original_initial_line
        except Exception as e:
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")

        return new_code_snippet
def get_docs_for_language(language, style):
    """Return the name of the documentation convention to request for *language*.

    Args:
        language: Programming language name; matched case-insensitively and
            ignoring surrounding whitespace. ``None`` or an empty string is
            treated as an unknown language.
        style: Docstring style, embedded in the result for the
            docstring-based languages only.

    Returns:
        "Javadocs", "Docstring (<style>)", "JSdocs", "Doxygen", or the generic
        "Docs" for any other language.
    """
    normalized = (language or "").strip().lower()
    if normalized == 'java':
        return "Javadocs"
    if normalized in ('python', 'lisp', 'clojure'):
        return f"Docstring ({style})"
    if normalized in ('javascript', 'typescript'):
        return "JSdocs"
    if normalized == 'c++':
        return "Doxygen"
    return "Docs"

View File

@ -0,0 +1,872 @@
import asyncio
import copy
import difflib
import re
import textwrap
import traceback
from datetime import datetime
from functools import partial
from typing import Dict, List
from jinja2 import Environment, StrictUndefined
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from utils.pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files,
get_pr_diff, get_pr_multi_diffs,
retry_with_fallback_models)
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.utils import (ModelType, load_yaml, replace_code_tags,
show_relevant_configurations)
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import (AzureDevopsProvider, GithubProvider,
get_git_provider_with_context)
from utils.pr_agent.git_providers.git_provider import get_main_pr_language, GitProvider
from utils.pr_agent.log import get_logger
from utils.pr_agent.servers.help import HelpMessage
from utils.pr_agent.tools.pr_description import insert_br_after_x_chars
class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
             ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
    """Collect PR metadata and prepare prompt variables for the code-suggestion run.

    Side effects on the settings object: ``config.max_model_tokens`` may be
    lowered to ``pr_code_suggestions.max_context_tokens`` (the previous value
    is kept in ``config.max_model_tokens_original``), and
    ``config.enable_ai_metadata`` is set to False unless AI metadata is
    actually added to the diff files.

    Args:
        pr_url: URL of the pull request.
        cli_mode: Stored on the instance as ``self.cli_mode``.
        args: Command arguments, used only to decide on extended mode.
        ai_handler: Zero-argument callable returning the AI handler to use.
    """
    self.git_provider = get_git_provider_with_context(pr_url)
    self.main_language = get_main_pr_language(
        self.git_provider.get_languages(), self.git_provider.get_files()
    )

    # limit context specifically for the improve command, which has hard input to parse:
    if get_settings().pr_code_suggestions.max_context_tokens:
        MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
        if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
            get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
            get_settings().config.max_model_tokens_original = get_settings().config.max_model_tokens
            get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE

    # extended mode
    try:
        self.is_extended = self._get_is_extended(args or [])
    except Exception as e:
        # Fall back to regular mode, but let KeyboardInterrupt/SystemExit through.
        get_logger().debug(f"Could not determine extended mode, using regular mode: {e}")
        self.is_extended = False
    num_code_suggestions = int(get_settings().pr_code_suggestions.num_code_suggestions_per_chunk)

    self.ai_handler = ai_handler()
    self.ai_handler.main_pr_language = self.main_language
    self.patches_diff = None
    self.prediction = None
    self.pr_url = pr_url
    self.cli_mode = cli_mode
    self.pr_description, self.pr_description_files = (
        self.git_provider.get_pr_description(split_changes_walkthrough=True))
    if (self.pr_description_files and get_settings().get("config.is_auto_command", False) and
            get_settings().get("config.enable_ai_metadata", False)):
        add_ai_metadata_to_diff_files(self.git_provider, self.pr_description_files)
        get_logger().debug(f"AI metadata added to the this command")
    else:
        get_settings().set("config.enable_ai_metadata", False)
        get_logger().debug(f"AI metadata is disabled for this command")

    self.vars = {
        "title": self.git_provider.pr.title,
        "branch": self.git_provider.get_pr_branch(),
        "description": self.pr_description,
        "language": self.main_language,
        "diff": "",  # empty diff for initial calculation
        "diff_no_line_numbers": "",  # empty diff for initial calculation
        "num_code_suggestions": num_code_suggestions,
        "extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
        "commit_messages_str": self.git_provider.get_commit_messages(),
        "relevant_best_practices": "",
        "is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
        "focus_only_on_problems": get_settings().get("pr_code_suggestions.focus_only_on_problems", False),
        "date": datetime.now().strftime('%Y-%m-%d'),
        'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False),
    }
    self.pr_code_suggestions_prompt_system = get_settings().pr_code_suggestions_prompt.system

    self.token_handler = TokenHandler(self.git_provider.pr,
                                      self.vars,
                                      self.pr_code_suggestions_prompt_system,
                                      get_settings().pr_code_suggestions_prompt.user)

    # Placeholder comment shown while suggestions are being generated.
    self.progress = f"## 生成 PR 代码建议\n\n"
    self.progress += f"""\n思考中 ...<br>\n<img src="https://codium.ai/images/pr_agent/dual_ball_loading-crop.gif" width=48>"""
    self.progress_response = None
async def run(self):
    """Generate code suggestions for the PR and publish them according to the settings.

    Flow: optionally post a progress comment, obtain suggestions from the
    model (regular or extended mode), then either publish a summary comment
    (optionally followed by dual publishing), push inline suggestions, or,
    when ``config.publish_output`` is off, store the summary in
    ``get_settings().data``.

    Every exception is logged and handled here, including errors raised while
    cleaning up after a failure, so this coroutine does not raise.
    """
    try:
        if not self.git_provider.get_files():
            get_logger().info(f"PR has no files: {self.pr_url}, skipping code suggestions")
            return None

        get_logger().info('Generating code suggestions for PR...')
        relevant_configs = {'pr_code_suggestions': dict(get_settings().pr_code_suggestions),
                            'config': dict(get_settings().config)}
        get_logger().debug("Relevant configs", artifacts=relevant_configs)

        # publish "Preparing suggestions..." comments
        if (get_settings().config.publish_output and get_settings().config.publish_output_progress and
                not get_settings().config.get('is_auto_command', False)):
            if self.git_provider.is_supported("gfm_markdown"):
                self.progress_response = self.git_provider.publish_comment(self.progress)
            else:
                self.git_provider.publish_comment("准备建议中...", is_temporary=True)

        # call the model to get the suggestions, and self-reflect on them
        if not self.is_extended:
            data = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
        else:
            data = await retry_with_fallback_models(self._prepare_prediction_extended,
                                                    model_type=ModelType.REGULAR)
        if not data:
            data = {"code_suggestions": []}
        self.data = data

        # Handle the case where the PR has no suggestions
        # (data is always a non-empty mapping here, so no None check is needed)
        if 'code_suggestions' not in data or not data['code_suggestions']:
            await self.publish_no_suggestions()
            return

        # publish the suggestions
        if get_settings().config.publish_output:
            # If a temporary comment was published, remove it
            self.git_provider.remove_initial_comment()

            # Publish table summarized suggestions
            if ((not get_settings().pr_code_suggestions.commitable_code_suggestions) and
                    self.git_provider.is_supported("gfm_markdown")):

                # generate summarized suggestions
                pr_body = self.generate_summarized_suggestions(data)
                get_logger().debug(f"PR output", artifact=pr_body)

                # require self-review
                if get_settings().pr_code_suggestions.demand_code_suggestions_self_review:
                    pr_body = await self.add_self_review_text(pr_body)

                # add usage guide
                if (get_settings().pr_code_suggestions.enable_chat_text and get_settings().config.is_auto_command
                        and isinstance(self.git_provider, GithubProvider)):
                    pr_body += "\n\n>💡 Need additional feedback ? start a [PR chat](https://chromewebstore.google.com/detail/ephlnjeghhogofkifjloamocljapahnl) \n\n"
                if get_settings().pr_code_suggestions.enable_help_text:
                    pr_body += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
                    pr_body += HelpMessage.get_improve_usage_guide()
                    pr_body += "\n</details>\n"

                # Output the relevant configurations if enabled
                if get_settings().get('config', {}).get('output_relevant_configurations', False):
                    pr_body += show_relevant_configurations(relevant_section='pr_code_suggestions')

                # publish the PR comment
                if get_settings().pr_code_suggestions.persistent_comment:  # true by default
                    self.publish_persistent_comment_with_history(
                        self.git_provider,
                        pr_body,
                        initial_header="## PR 代码建议 ✨",
                        update_header=True,
                        name="suggestions",
                        final_update_message=False,
                        max_previous_comments=get_settings().pr_code_suggestions.max_history_len,
                        progress_response=self.progress_response)
                else:
                    if self.progress_response:
                        self.git_provider.edit_comment(self.progress_response, body=pr_body)
                    else:
                        self.git_provider.publish_comment(pr_body)

                # dual publishing mode
                if int(get_settings().pr_code_suggestions.dual_publishing_score_threshold) > 0:
                    await self.dual_publishing(data)
            else:
                await self.push_inline_code_suggestions(data)
                if self.progress_response:
                    self.git_provider.remove_comment(self.progress_response)
        else:
            get_logger().info('Code suggestions generated for PR, but not published since publish_output is False.')
            pr_body = self.generate_summarized_suggestions(data)
            get_settings().data = {"artifact": pr_body}
            return
    except Exception as e:
        get_logger().error(f"Failed to generate code suggestions for PR, error: {e}",
                           artifact={"traceback": traceback.format_exc()})
        if get_settings().config.publish_output:
            # Cleanup is best effort: an error raised here must not escape run()
            # and hide the failure that was just logged.
            try:
                if self.progress_response:
                    # Same provider call the success path uses to drop the progress comment.
                    self.git_provider.remove_comment(self.progress_response)
                else:
                    self.git_provider.remove_initial_comment()
                    self.git_provider.publish_comment(f"Failed to generate code suggestions for PR")
            except Exception as cleanup_error:
                get_logger().exception(f"Failed to update persistent review, error: {cleanup_error}")
async def add_self_review_text(self, pr_body):
    """Append the self-review checkbox and its hidden marker comment to ``pr_body``.

    The marker is chosen from the ``approve_pr_on_self_review`` and
    ``fold_suggestions_on_self_review`` settings.

    Returns:
        ``pr_body`` followed by an unchecked markdown checkbox and one HTML comment.
    """
    review_settings = get_settings().pr_code_suggestions
    checkbox_text = review_settings.code_suggestions_self_review_text
    approve = review_settings.approve_pr_on_self_review
    fold = review_settings.fold_suggestions_on_self_review

    if approve and not fold:
        marker = ' <!-- approve pr self-review -->'
    elif fold and not approve:
        marker = ' <!-- fold suggestions self-review -->'
    else:
        # Reached when both flags are set, and also when neither is.
        # NOTE(review): confirm the combined marker is intended when both are off.
        marker = ' <!-- approve and fold suggestions self-review -->'

    return pr_body + f"\n\n- [ ] {checkbox_text}" + marker
async def publish_no_suggestions(self):
    """Report that no code suggestions were found for the PR.

    When both ``config.publish_output`` and
    ``config.publish_output_no_suggestions`` are enabled, the message replaces
    the progress comment if one exists, otherwise it is posted as a new
    comment. In every other case only an empty artifact is stored in
    ``get_settings().data``.
    """
    pr_body = "## PR 代码建议 ✨\n\n未找到该PR的代码建议."
    config = get_settings().config

    if not (config.publish_output and config.publish_output_no_suggestions):
        get_settings().data = {"artifact": ""}
        return

    get_logger().warning('No code suggestions found for the PR.')
    get_logger().debug("PR output", artifact=pr_body)
    if self.progress_response:
        self.git_provider.edit_comment(self.progress_response, body=pr_body)
        return
    self.git_provider.publish_comment(pr_body)
async def dual_publishing(self, data):
    """Additionally publish high-scoring suggestions as inline code suggestions.

    A suggestion is selected when its score reaches
    ``pr_code_suggestions.dual_publishing_score_threshold`` and it has a
    non-empty ``improved_code``. Selected suggestions with an empty
    ``existing_code`` get it filled from ``improved_code`` (the suggestion dict
    is modified in place). Any exception is logged and swallowed.
    """
    selected = []
    above_threshold = {'code_suggestions': selected}
    try:
        for item in data['code_suggestions']:
            reaches_threshold = int(item.get('score', 0)) >= int(
                get_settings().pr_code_suggestions.dual_publishing_score_threshold)
            if not (reaches_threshold and item.get('improved_code')):
                continue
            selected.append(item)
            if not item['existing_code']:
                get_logger().info('Identical existing and improved code for dual publishing found')
                item['existing_code'] = item['improved_code']

        if selected:
            get_logger().info(
                f"Publishing {len(selected)} suggestions in dual publishing mode")
            await self.push_inline_code_suggestions(above_threshold)
    except Exception as e:
        get_logger().error(f"Failed to publish dual publishing suggestions, error: {e}")
@staticmethod
def publish_persistent_comment_with_history(git_provider: GitProvider,
                                            pr_comment: str,
                                            initial_header: str,
                                            update_header: bool = True,
                                            name='review',
                                            final_update_message=True,
                                            max_previous_comments=4,
                                            progress_response=None,
                                            only_fold=False):
    """Publish ``pr_comment`` as a persistent comment that keeps older suggestions as history.

    If an existing issue comment starts with ``initial_header``, its current
    suggestions table is folded into a ``<details>`` entry under a
    "Previous suggestions" section and the new suggestions are placed on top.
    Otherwise a fresh comment is published.

    Args:
        git_provider: Provider used to list, edit, remove and publish comments.
        pr_comment: Full text of the new comment (including ``initial_header``).
        initial_header: Header that identifies the persistent comment.
        update_header: Accepted for interface compatibility; not used here.
        name: Title used (capitalised) for each history entry.
        final_update_message: Accepted for interface compatibility; not used here.
        max_previous_comments: History size limit; ``0`` or less disables the
            history handling entirely.
        progress_response: Optional "in progress" comment. When given, the
            result is written into it and the old persistent comment is removed.
        only_fold: When true, the latest-suggestions header is replaced by the
            checked self-review checkbox line.

    Returns:
        The comment object that now holds the published text.
    """
    done_mark = "✅"

    def _extract_link(comment_text: str):
        # Text of the first HTML comment (the commit marker), if any.
        match = re.search(r"<!--.*?-->", comment_text)
        if not match:
            return ""
        return f" up to commit {match.group(0)[4:-3].strip()}"

    if isinstance(git_provider, AzureDevopsProvider):  # get_latest_commit_url is not supported yet
        if progress_response:
            git_provider.edit_comment(progress_response, pr_comment)
            return progress_response
        return git_provider.publish_comment(pr_comment)

    history_header = f"#### Previous suggestions\n"
    last_commit_num = git_provider.get_latest_commit_url().split('/')[-1][:7]
    if only_fold:  # A user clicked on the 'self-review' checkbox
        text = get_settings().pr_code_suggestions.code_suggestions_self_review_text
        latest_suggestion_header = f"\n\n- [x] {text}"
    else:
        latest_suggestion_header = f"Latest suggestions up to {last_commit_num}"
    latest_commit_html_comment = f"<!-- {last_commit_num} -->"
    title = name.capitalize()

    if max_previous_comments > 0:
        try:
            prev_comments = list(git_provider.get_issue_comments())
            for comment in prev_comments:
                if not comment.body.startswith(initial_header):
                    continue
                prev_suggestions = comment.body
                comment_url = git_provider.get_comment_url(comment)

                if history_header.strip() not in comment.body:
                    # No history section yet: fold the whole current table
                    # (from <table> through the last </table>) into the first entry.
                    table_index = comment.body.find("<table>")
                    if table_index == -1:
                        git_provider.edit_comment(comment, pr_comment)
                        continue
                    # commit marker from the text that precedes the table
                    up_to_commit_txt = _extract_link(comment.body[:table_index])
                    prev_suggestion_table = comment.body[
                        table_index:comment.body.rfind("</table>") + len("</table>")]

                    # Prefix the entry when the folded table contains the mark; this is
                    # the prefix the history counter below looks for.
                    tick = f"{done_mark} " if done_mark in prev_suggestion_table else ""
                    # surround with details tag
                    prev_suggestion_table = (f"<details><summary>{tick}{title}{up_to_commit_txt}</summary>"
                                             f"\n<br>{prev_suggestion_table}\n\n</details>")

                    new_suggestion_table = pr_comment.replace(initial_header, "").strip()
                    pr_comment_updated = f"{initial_header}\n{latest_commit_html_comment}\n\n"
                    pr_comment_updated += f"{latest_suggestion_header}\n{new_suggestion_table}\n\n___\n\n"
                    pr_comment_updated += f"{history_header}{prev_suggestion_table}\n"
                else:
                    # History exists: split into the latest part and the accumulated history.
                    sections = prev_suggestions.split(history_header.strip())
                    latest_table = sections[0].strip()
                    prev_suggestion_table = sections[1].replace(history_header, "").strip()

                    table_ind = latest_table.find("<table>")
                    up_to_commit_txt = _extract_link(latest_table[:table_ind])
                    latest_table = latest_table[table_ind:latest_table.rfind("</table>") + len("</table>")]

                    # enforce max_previous_comments
                    count = prev_suggestions.count(f"\n<details><summary>{title}")
                    count += prev_suggestions.count(f"\n<details><summary>{done_mark} {title}")
                    if count >= max_previous_comments:
                        # Remove the oldest (last) entry, whichever prefix it has.
                        oldest_start = max(
                            prev_suggestion_table.rfind(f"<details><summary>{title} up to commit"),
                            prev_suggestion_table.rfind(f"<details><summary>{done_mark} {title} up to commit"))
                        # rfind() gives -1 when nothing matches; slicing with it would
                        # drop the final character and corrupt the closing tag.
                        if oldest_start != -1:
                            prev_suggestion_table = prev_suggestion_table[:oldest_start]

                    tick = f"{done_mark} " if done_mark in latest_table else ""
                    # Add the previous "latest" table on top of the history section
                    last_prev_table = (f"\n<details><summary>{tick}{title}{up_to_commit_txt}</summary>"
                                       f"\n<br>{latest_table}\n\n</details>")
                    prev_suggestion_table = last_prev_table + "\n" + prev_suggestion_table

                    new_suggestion_table = pr_comment.replace(initial_header, "").strip()
                    pr_comment_updated = f"{initial_header}\n"
                    pr_comment_updated += f"{latest_commit_html_comment}\n\n"
                    pr_comment_updated += f"{latest_suggestion_header}\n\n{new_suggestion_table}\n\n"
                    pr_comment_updated += "___\n\n"
                    pr_comment_updated += f"{history_header}\n"
                    pr_comment_updated += f"{prev_suggestion_table}\n"

                get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message")
                if progress_response:  # publish to 'progress_response' comment, because it refreshes immediately
                    git_provider.edit_comment(progress_response, pr_comment_updated)
                    git_provider.remove_comment(comment)
                    comment = progress_response
                else:
                    git_provider.edit_comment(comment, pr_comment_updated)
                return comment
        except Exception as e:
            # Best effort: fall through and publish a fresh comment below.
            get_logger().exception(f"Failed to update persistent review, error: {e}")

    # if we are here, we did not find a previous comment to update
    body = pr_comment.replace(initial_header, "").strip()
    pr_comment = f"{initial_header}\n\n{latest_commit_html_comment}\n\n{body}\n\n"
    if progress_response:
        git_provider.edit_comment(progress_response, pr_comment)
        return progress_response
    return git_provider.publish_comment(pr_comment)
def extract_link(self, s):
    """Return ``" up to commit <text>"`` built from the first HTML comment in ``s``.

    ``<text>`` is the stripped content between ``<!--`` and ``-->``. An empty
    string is returned when ``s`` contains no HTML comment.
    """
    found = re.search(r"<!--.*?-->", s)
    if found is None:
        return ""
    marker = found.group(0)
    # drop the leading "<!--" and the trailing "-->"
    return f" up to commit {marker[4:-3].strip()}"
async def _prepare_prediction(self, model: str) -> dict:
    """Build the PR diff and, if it is non-empty, request a prediction for it.

    Stores the diff (with and without line numbers) and the prediction on the
    instance. Returns the prediction, or ``None`` when the diff is empty.
    """
    diff = get_pr_diff(self.git_provider,
                       self.token_handler,
                       model,
                       add_line_numbers_to_hunks=True,
                       disable_extra_lines=False)
    self.patches_diff = diff
    self.patches_diff_list = [diff]
    self.patches_diff_no_line_number = self.remove_line_numbers([diff])[0]

    if not diff:
        get_logger().warning("Empty PR diff")
        self.prediction = None
    else:
        get_logger().debug("PR diff", artifact=diff)
        self.prediction = await self._get_prediction(model, diff, self.patches_diff_no_line_number)
    return self.prediction
async def _get_prediction(self, model: str, patches_diff: str, patches_diff_no_line_number: str) -> dict:
    """Ask the model for code suggestions on one diff chunk and score them.

    Renders the system/user prompts with the diff, calls the AI handler, parses
    the response into suggestions, and then runs the self-reflection pass. When
    that pass yields no response, every suggestion gets a default score of 7.

    Returns:
        The parsed suggestions dictionary (key ``code_suggestions``).
    """
    prompt_vars = copy.deepcopy(self.vars)
    prompt_vars["diff"] = patches_diff  # update diff
    prompt_vars["diff_no_line_numbers"] = patches_diff_no_line_number  # update diff

    jinja_env = Environment(undefined=StrictUndefined)
    system_prompt = jinja_env.from_string(self.pr_code_suggestions_prompt_system).render(prompt_vars)
    user_prompt = jinja_env.from_string(get_settings().pr_code_suggestions_prompt.user).render(prompt_vars)

    response, _finish_reason = await self.ai_handler.chat_completion(
        model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)

    # keep the rendered prompts on the settings object when nothing is published
    if not get_settings().config.publish_output:
        get_settings().system_prompt = system_prompt
        get_settings().user_prompt = user_prompt

    # load suggestions from the AI response
    data = self._prepare_pr_code_suggestions(response)

    # self-reflect on suggestions (mandatory, since line numbers are generated now here)
    reflection = await self.self_reflect_on_suggestions(data["code_suggestions"],
                                                        patches_diff,
                                                        model=get_settings().config.model)
    if not reflection:
        # no reflection available: fall back to a default score
        for suggestion in data["code_suggestions"]:
            suggestion["score"] = 7
            suggestion["score_why"] = ""
    else:
        await self.analyze_self_reflection_response(data, reflection)
    return data
async def analyze_self_reflection_response(self, data, response_reflect):
    """Merge the self-reflection feedback into the suggestions in ``data``.

    The feedback is only applied when it contains exactly one entry per
    suggestion. For each suggestion this copies the score and its reasoning,
    fills in missing line numbers (zeroing the score when either is negative),
    logs statistics when output is published, and finally blanks one of
    ``existing_code`` / ``improved_code`` when the two are identical.
    Suggestions are modified in place; nothing is returned.
    """
    feedback_yaml = load_yaml(response_reflect)
    code_suggestions_feedback = feedback_yaml.get("code_suggestions", [])
    suggestions = data["code_suggestions"]
    if not (code_suggestions_feedback and len(code_suggestions_feedback) == len(suggestions)):
        return

    for i, suggestion in enumerate(suggestions):
        try:
            feedback = code_suggestions_feedback[i]
            suggestion["score"] = feedback["suggestion_score"]
            suggestion["score_why"] = feedback["why"]

            if 'relevant_lines_start' not in suggestion:
                start_line = feedback.get('relevant_lines_start', -1)
                end_line = feedback.get('relevant_lines_end', -1)
                suggestion['relevant_lines_start'] = start_line
                suggestion['relevant_lines_end'] = end_line
                if start_line < 0 or end_line < 0:
                    suggestion["score"] = 0

            # statistics logging is best effort and must not affect scoring
            try:
                if get_settings().config.publish_output:
                    score = int(suggestion["score"]) if suggestion["score"] else -1
                    label = suggestion["label"].lower().strip().replace('<br>', ' ')
                    get_logger().info("PR-Agent suggestions statistics",
                                      statistics={'score': score, 'label': label},
                                      analytics=True)
            except Exception as e:
                get_logger().error(f"Failed to log suggestion statistics, error: {e}")
        except Exception as e:
            get_logger().error(f"Error processing suggestion score {i}",
                               artifact={"suggestion": suggestion,
                                         "code_suggestions_feedback": code_suggestions_feedback[i]})
            suggestion["score"] = 7
            suggestion["score_why"] = ""

        # if the before and after code is the same, clear one of them
        try:
            if suggestion['existing_code'] == suggestion['improved_code']:
                get_logger().debug(
                    f"edited improved suggestion {i + 1}, because equal to existing code: {suggestion['existing_code']}")
                if get_settings().pr_code_suggestions.commitable_code_suggestions:
                    # we need 'existing_code' to locate the code in the PR
                    key_to_clear = 'improved_code'
                else:
                    key_to_clear = 'existing_code'
                suggestion[key_to_clear] = ""
        except Exception as e:
            get_logger().error(f"Error processing suggestion {i + 1}, error: {e}")
@staticmethod
def _truncate_if_needed(suggestion):
    """Cut ``suggestion['improved_code']`` down to the configured maximum length.

    The limit is ``PR_CODE_SUGGESTIONS.MAX_CODE_SUGGESTION_LENGTH``; a value of
    ``0`` or less disables truncation. When the code is cut, the configured
    truncation message is appended on a new line. The suggestion is modified
    in place and also returned.
    """
    settings = get_settings()
    limit = settings.get("PR_CODE_SUGGESTIONS.MAX_CODE_SUGGESTION_LENGTH", 0)
    truncation_note = settings.get("PR_CODE_SUGGESTIONS.SUGGESTION_TRUNCATION_MESSAGE", "")
    if limit > 0 and len(suggestion['improved_code']) > limit:
        get_logger().info(f"Truncated suggestion from {len(suggestion['improved_code'])} "
                          f"characters to {limit} characters")
        shortened = suggestion['improved_code'][:limit]
        suggestion['improved_code'] = shortened + f"\n{truncation_note}"
    return suggestion
def _prepare_pr_code_suggestions(self, predictions: str) -> Dict:
    """Parse the model response and keep only the usable suggestions.

    A suggestion is dropped when it lacks a required key, repeats an already
    accepted one-sentence summary, mentions 'const', 'instead' and 'let' in its
    content, or lacks ``existing_code`` / ``improved_code``. Kept suggestions
    are passed through ``_truncate_if_needed``.

    Returns:
        The parsed data with ``code_suggestions`` replaced by the filtered list.
    """
    data = load_yaml(predictions.strip(),
                     keys_fix_yaml=["relevant_file", "suggestion_content", "existing_code", "improved_code"],
                     first_key="code_suggestions", last_key="label")
    if isinstance(data, list):
        data = {'code_suggestions': data}

    # remove or edit invalid suggestions
    required_keys = ('one_sentence_summary', 'label', 'relevant_file')
    kept = []
    seen_summaries = []
    for i, suggestion in enumerate(data['code_suggestions']):
        try:
            missing_key = next((key for key in required_keys if key not in suggestion), None)
            if missing_key is not None:
                get_logger().debug(
                    f"Skipping suggestion {i + 1}, because it does not contain '{missing_key}':\n'{suggestion}")
                continue

            if get_settings().get("pr_code_suggestions.focus_only_on_problems", False):
                # we want the published labels to be less declarative
                if 'critical' in suggestion['label'].lower():
                    suggestion['label'] = 'possible issue'

            if suggestion['one_sentence_summary'] in seen_summaries:
                get_logger().debug(f"Skipping suggestion {i + 1}, because it is a duplicate: {suggestion}")
                continue

            content = suggestion['suggestion_content']
            if all(word in content for word in ('const', 'instead', 'let')):
                get_logger().debug(
                    f"Skipping suggestion {i + 1}, because it uses 'const instead let': {suggestion}")
                continue

            if not (('existing_code' in suggestion) and ('improved_code' in suggestion)):
                get_logger().info(
                    f"Skipping suggestion {i + 1}, because it does not contain 'existing_code' or 'improved_code': {suggestion}")
                continue

            suggestion = self._truncate_if_needed(suggestion)
            seen_summaries.append(suggestion['one_sentence_summary'])
            kept.append(suggestion)
        except Exception as e:
            get_logger().error(f"Error processing suggestion {i + 1}: {suggestion}, error: {e}")

    data['code_suggestions'] = kept
    return data
async def push_inline_code_suggestions(self, data):
    """Publish the suggestions in ``data`` as inline (committable) code suggestions.

    With no suggestions, a "nothing found" message is published instead and
    the provider's return value is returned. Otherwise each suggestion is
    turned into a ``suggestion`` block; entries that cannot be parsed are
    logged and skipped. If the batch publish reports failure, every suggestion
    is published on its own.
    """
    if not data['code_suggestions']:
        get_logger().info('No suggestions found to improve this PR.')
        if self.progress_response:
            return self.git_provider.edit_comment(self.progress_response,
                                                  body='No suggestions found to improve this PR.')
        return self.git_provider.publish_comment('No suggestions found to improve this PR.')

    prepared = []
    for d in data['code_suggestions']:
        try:
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"suggestion: {d}")
            relevant_file = d['relevant_file'].strip()
            first_line = int(d['relevant_lines_start'])  # absolute position
            last_line = int(d['relevant_lines_end'])
            content = d['suggestion_content'].rstrip()
            snippet = d['improved_code'].rstrip()
            label = d['label'].strip()

            if snippet:
                snippet = self.dedent_code(relevant_file, first_line, snippet)

            score = d.get('score')
            tag = f"{label}, importance: {score}" if score else f"{label}"
            body = f"**Suggestion:** {content} [{tag}]\n```suggestion\n" + snippet + "\n```"
            prepared.append({'body': body,
                             'relevant_file': relevant_file,
                             'relevant_lines_start': first_line,
                             'relevant_lines_end': last_line,
                             'original_suggestion': d})
        except Exception:
            get_logger().info(f"Could not parse suggestion: {d}")

    if self.git_provider.publish_code_suggestions(prepared):
        return
    get_logger().info("Failed to publish code suggestions, trying to publish each suggestion separately")
    for single in prepared:
        self.git_provider.publish_code_suggestions([single])
def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet):
    """Align the indentation of ``new_code_snippet`` with the line it replaces.

    The head version of ``relevant_file`` is looked up in the provider's diff
    files. If the original line at ``relevant_lines_start`` (1-based) is
    indented deeper than the first line of the snippet, the whole snippet is
    indented by the difference. Despite the name, the snippet is never
    de-indented.

    Args:
        relevant_file: File name, compared against each stripped diff file name.
        relevant_lines_start: 1-based line number of the first replaced line.
        new_code_snippet: Suggested replacement code.

    Returns:
        The possibly re-indented snippet. The snippet is returned unchanged
        when the file or line cannot be resolved, or on any error.
    """
    try:  # dedent code snippet
        self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files \
            else self.git_provider.get_diff_files()
        original_initial_line = None
        for file in self.diff_files:
            if file.filename.strip() != relevant_file:
                continue
            if not file.head_file:
                get_logger().warning("Could not dedent code snippet, because head_file is missing",
                                     artifact={'filename': file.filename,
                                               'relevant_lines_start': relevant_lines_start,
                                               'new_code_snippet': new_code_snippet})
                return new_code_snippet
            file_lines = file.head_file.splitlines()
            # Line numbers are 1-based. Values below 1 (e.g. a -1 placeholder) must be
            # rejected too, otherwise the negative index silently reads from the file end.
            if not 1 <= relevant_lines_start <= len(file_lines):
                get_logger().warning(
                    "Could not dedent code snippet, because relevant_lines_start is out of range",
                    artifact={'filename': file.filename,
                              'file_content': file.head_file,
                              'relevant_lines_start': relevant_lines_start,
                              'new_code_snippet': new_code_snippet})
                return new_code_snippet
            original_initial_line = file_lines[relevant_lines_start - 1]
            break

        snippet_lines = new_code_snippet.splitlines()
        # An empty snippet has no first line to compare against.
        if original_initial_line and snippet_lines:
            suggested_initial_line = snippet_lines[0]
            original_initial_spaces = len(original_initial_line) - len(original_initial_line.lstrip())
            suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip())
            delta_spaces = original_initial_spaces - suggested_initial_spaces
            if delta_spaces > 0:
                new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
    except Exception as e:
        get_logger().error(f"Error when dedenting code snippet for file {relevant_file}, error: {e}")
    return new_code_snippet
def _get_is_extended(self, args: list[str]) -> bool:
    """Check if extended mode should be enabled by the `--extended` flag or automatically according to the configuration"""
    flag_given = any("extended" in arg for arg in args)
    if flag_given:
        get_logger().info("Extended mode is enabled by the `--extended` flag")
        return True
    # otherwise the configuration toggle decides
    return bool(get_settings().pr_code_suggestions.auto_extended_mode)
def remove_line_numbers(self, patches_diff_list: List[str]) -> List[str]:
    """Return copies of the given diffs with leading line numbers removed.

    For every non-blank line: a line consisting only of digits becomes empty,
    and a line starting with a digit loses its leading digits plus the single
    character that follows them. Other lines are kept as they are.

    The result is also stored in ``self.patches_diff_list_no_line_numbers``.

    Args:
        patches_diff_list: The diff texts to process.

    Returns:
        The processed diffs, or ``patches_diff_list`` itself if an error occurs.
    """
    try:
        self.patches_diff_list_no_line_numbers = []
        # Iterate over the argument, not self.patches_diff_list, so the result
        # always corresponds to what the caller passed in.
        for patches_diff in patches_diff_list:
            patches_diff_lines = patches_diff.splitlines()
            for i, line in enumerate(patches_diff_lines):
                if not line.strip():
                    continue
                if line.isnumeric():
                    patches_diff_lines[i] = ''
                elif line[0].isdigit():
                    # cut at the first non-digit character, dropping it as well
                    for j, char in enumerate(line):
                        if not char.isdigit():
                            patches_diff_lines[i] = line[j + 1:]
                            break
            self.patches_diff_list_no_line_numbers.append('\n'.join(patches_diff_lines))
        return self.patches_diff_list_no_line_numbers
    except Exception as e:
        get_logger().error(f"Error removing line numbers from patches_diff_list, error: {e}")
        return patches_diff_list
async def _prepare_prediction_extended(self, model: str) -> dict:
    """Split the PR diff into chunks, get predictions per chunk and merge them.

    Chunks are processed concurrently when ``pr_code_suggestions.parallel_calls``
    is set, otherwise one after another. Suggestions whose score is below
    ``suggestions_score_threshold`` (at least 1) are dropped.

    Returns:
        ``{"code_suggestions": [...]}``, or ``None`` when there are no chunks.
        The same value is stored in ``self.data``.
    """
    self.patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
                                                max_calls=get_settings().pr_code_suggestions.max_number_of_calls)
    # create a copy of the patches_diff_list, without line numbers for '__new hunk__' sections
    self.patches_diff_list_no_line_numbers = self.remove_line_numbers(self.patches_diff_list)

    if not self.patches_diff_list:
        get_logger().warning("Empty PR diff list")
        self.data = None
        return None

    get_logger().info(f"Number of PR chunk calls: {len(self.patches_diff_list)}")
    get_logger().debug("PR diff:", artifact=self.patches_diff_list)

    if get_settings().pr_code_suggestions.parallel_calls:
        # parallelize calls to AI
        pending = [self._get_prediction(model, chunk, chunk_plain)
                   for chunk, chunk_plain in zip(self.patches_diff_list, self.patches_diff_list_no_line_numbers)]
        prediction_list = await asyncio.gather(*pending)
        # NOTE(review): prediction_list is stored on the instance only on this path.
        self.prediction_list = prediction_list
    else:
        prediction_list = []
        for chunk, chunk_plain in zip(self.patches_diff_list, self.patches_diff_list_no_line_numbers):
            prediction_list.append(await self._get_prediction(model, chunk, chunk_plain))

    accepted = []
    for j, predictions in enumerate(prediction_list):  # each call adds an element to the list
        if "code_suggestions" not in predictions:
            continue
        score_threshold = max(1, int(get_settings().pr_code_suggestions.suggestions_score_threshold))
        for i, prediction in enumerate(predictions["code_suggestions"]):
            try:
                score = int(prediction.get("score", 1))
                if score < score_threshold:
                    get_logger().info(
                        f"Removing suggestions {i} from call {j}, because score is {score}, and score_threshold is {score_threshold}",
                        artifact=prediction)
                else:
                    accepted.append(prediction)
            except Exception as e:
                get_logger().error(f"Error getting PR diff for suggestion {i} in call {j}, error: {e}",
                                   artifact={"prediction": prediction})

    data = {"code_suggestions": accepted}
    self.data = data
    return data
def generate_summarized_suggestions(self, data: Dict) -> str:
# Render the suggestions in `data` as one markdown/HTML comment body.
#
# The body is a heading followed by an HTML table with three columns
# (category, suggestion, impact). Suggestions are grouped by label; groups are
# ordered by their highest score and suggestions inside a group by score, both
# descending. Each suggestion is a collapsible <details> element holding the
# suggestion text, a link to the affected lines and a diff of existing vs.
# improved code.
#
# Returns the rendered body. With no suggestions, the heading plus a
# "no suggestions" sentence is returned. On any exception the error is logged
# (at info level) and an empty string is returned.
try:
pr_body = "## PR 代码建议 ✨\n\n"
if len(data.get('code_suggestions', [])) == 0:
pr_body += "No suggestions found to improve this PR."
return pr_body
if get_settings().pr_code_suggestions.enable_intro_text and get_settings().config.is_auto_command:
pr_body += "Explore these optional code suggestions:\n\n"
# NOTE(review): extension_to_language is built here but not read again in this function.
language_extension_map_org = get_settings().language_extension_map_org
extension_to_language = {}
for language, extensions in language_extension_map_org.items():
for ext in extensions:
extension_to_language[ext] = language
pr_body += "<table>"
# Pad the suggestion column header with repeated non-breaking spaces
# (presumably to force a wide column; confirm against the rendered comment).
header = f"建议"
delta = 66
header += "&nbsp; " * delta
pr_body += f"""<thead><tr><td><strong>类别</strong></td><td align=left><strong>{header}</strong></td><td align=center><strong>影响</strong></td></tr>"""
pr_body += """<tbody>"""
suggestions_labels = dict()
# add all suggestions related to each label
for suggestion in data['code_suggestions']:
label = suggestion['label'].strip().strip("'").strip('"')
if label not in suggestions_labels:
suggestions_labels[label] = []
suggestions_labels[label].append(suggestion)
# sort suggestions_labels by the suggestion with the highest score
suggestions_labels = dict(
sorted(suggestions_labels.items(), key=lambda x: max([s['score'] for s in x[1]]), reverse=True))
# sort the suggestions inside each label group by score
for label, suggestions in suggestions_labels.items():
suggestions_labels[label] = sorted(suggestions, key=lambda x: x['score'], reverse=True)
# NOTE(review): counter_suggestions is incremented below but never read in this function.
counter_suggestions = 0
for label, suggestions in suggestions_labels.items():
num_suggestions = len(suggestions)
# The label cell spans all rows of its group.
pr_body += f"""<tr><td rowspan={num_suggestions}>{label.capitalize()}</td>\n"""
for i, suggestion in enumerate(suggestions):
relevant_file = suggestion['relevant_file'].strip()
relevant_lines_start = int(suggestion['relevant_lines_start'])
relevant_lines_end = int(suggestion['relevant_lines_end'])
range_str = ""
if relevant_lines_start == relevant_lines_end:
range_str = f"[{relevant_lines_start}]"
else:
range_str = f"[{relevant_lines_start}-{relevant_lines_end}]"
# The line link is optional: any failure to build it falls back to an empty link.
try:
code_snippet_link = self.git_provider.get_line_link(relevant_file, relevant_lines_start,
relevant_lines_end)
except:
code_snippet_link = ""
# add html table for each suggestion
suggestion_content = suggestion['suggestion_content'].rstrip()
CHAR_LIMIT_PER_LINE = 84
suggestion_content = insert_br_after_x_chars(suggestion_content, CHAR_LIMIT_PER_LINE)
# pr_body += f"<tr><td><details><summary>{suggestion_content}</summary>"
existing_code = suggestion['existing_code'].rstrip() + "\n"
improved_code = suggestion['improved_code'].rstrip() + "\n"
# Diff existing vs. improved code with a large context (n=999), then skip
# the first five lines of the joined output (the diff header area).
diff = difflib.unified_diff(existing_code.split('\n'),
improved_code.split('\n'), n=999)
patch_orig = "\n".join(diff)
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
example_code = ""
example_code += f"```diff\n{patch.rstrip()}\n```\n"
# The first suggestion of a group continues the row opened with the label cell;
# every following suggestion starts a new row.
if i == 0:
pr_body += f"""<td>\n\n"""
else:
pr_body += f"""<tr><td>\n\n"""
suggestion_summary = suggestion['one_sentence_summary'].strip().rstrip('.')
if "'<" in suggestion_summary and ">'" in suggestion_summary:
# escape the '<' and '>' characters, otherwise they are interpreted as html tags
get_logger().info(f"Escaped suggestion summary: {suggestion_summary}")
suggestion_summary = suggestion_summary.replace("'<", "`<")
suggestion_summary = suggestion_summary.replace(">'", ">`")
if '`' in suggestion_summary:
suggestion_summary = replace_code_tags(suggestion_summary)
pr_body += f"""\n\n<details><summary>{suggestion_summary}</summary>\n\n___\n\n"""
# NOTE(review): the template below is emitted verbatim, line breaks included;
# confirm the fenced diff still renders without blank lines between its parts.
pr_body += f"""
**{suggestion_content}**
[{relevant_file} {range_str}]({code_snippet_link})
{example_code.rstrip()}
"""
# Nested collapsible section with the score and its reasoning, when a reason exists.
if suggestion.get('score_why'):
pr_body += f"<details><summary>严重性 [1-10]: {suggestion['score']}</summary>\n\n"
pr_body += f"__\n\nWhy: {suggestion['score_why']}\n\n"
pr_body += f"</details>"
pr_body += f"</details>"
# # add another column for 'score'
score_int = int(suggestion.get('score', 0))
score_str = f"{score_int}"
# With the new score mechanism, the number is replaced by the label from get_score_str.
if get_settings().pr_code_suggestions.new_score_mechanism:
score_str = self.get_score_str(score_int)
pr_body += f"</td><td align=center>{score_str}\n\n"
pr_body += f"</td></tr>"
counter_suggestions += 1
# pr_body += "</details>"
# pr_body += """</td></tr>"""
pr_body += """</tr></tbody></table>"""
return pr_body
except Exception as e:
get_logger().info(f"Failed to publish summarized code suggestions, error: {e}")
return ""
def get_score_str(self, score: int) -> str:
    """Map a numeric score to its impact label.

    Thresholds come from ``pr_code_suggestions.new_score_mechanism_th_high``
    (default 9) and ``new_score_mechanism_th_medium`` (default 7).

    NOTE(review): all three labels are empty strings here, so the impact column
    renders blank when the new score mechanism is on. The label text looks like
    it was lost; confirm the intended values.
    """
    suggestions_cfg = get_settings().pr_code_suggestions
    th_high = suggestions_cfg.get('new_score_mechanism_th_high', 9)
    th_medium = suggestions_cfg.get('new_score_mechanism_th_medium', 7)
    if score >= th_high:
        return ""
    if score >= th_medium:
        return ""
    # score is below the medium threshold
    return ""
async def self_reflect_on_suggestions(self,
                                      suggestion_list: List,
                                      patches_diff: str,
                                      model: str,
                                      prev_suggestions_str: str = "",
                                      dedicated_prompt: str = "") -> str:
    """Ask the model to review (score) the given suggestions.

    Args:
        suggestion_list: Suggestions to reflect on.
        patches_diff: Diff the suggestions refer to.
        model: Model used for the reflection call.
        prev_suggestions_str: Passed to the prompt as ``prev_suggestions_str``.
        dedicated_prompt: Settings key of an alternative prompt pair; when
            empty, ``pr_code_suggestions_reflect_prompt`` is used.

    Returns:
        The raw model response, or ``""`` when there is nothing to reflect on
        or the call fails (the error is logged).
    """
    if not suggestion_list:
        return ""

    try:
        suggestion_str = "".join(
            f"suggestion {number}: {suggestion!s}\n\n"
            for number, suggestion in enumerate(suggestion_list, start=1))

        variables = {'suggestion_list': suggestion_list,
                     'suggestion_str': suggestion_str,
                     "diff": patches_diff,
                     'num_code_suggestions': len(suggestion_list),
                     'prev_suggestions_str': prev_suggestions_str,
                     "is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
                     'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False)}

        jinja_env = Environment(undefined=StrictUndefined)
        if dedicated_prompt:
            system_template = get_settings().get(dedicated_prompt).system
        else:
            system_template = get_settings().pr_code_suggestions_reflect_prompt.system
        system_prompt_reflect = jinja_env.from_string(system_template).render(variables)

        if dedicated_prompt:
            user_template = get_settings().get(dedicated_prompt).user
        else:
            user_template = get_settings().pr_code_suggestions_reflect_prompt.user
        user_prompt_reflect = jinja_env.from_string(user_template).render(variables)

        with get_logger().contextualize(command="self_reflect_on_suggestions"):
            response_reflect, _finish_reason = await self.ai_handler.chat_completion(
                model=model,
                system=system_prompt_reflect,
                user=user_prompt_reflect)
        return response_reflect
    except Exception as e:
        get_logger().info(f"Could not reflect on suggestions, error: {e}")
        return ""

Some files were not shown because too many files have changed in this diff Show More