first commit
commit 578f5d9a3d
15  .gitignore  vendored  Normal file
@@ -0,0 +1,15 @@
.idea/
.lsp/
.vscode/
.env
venv/
pr_agent/settings/.secrets.toml
__pycache__
dist/
*.egg-info/
build/
.DS_Store
docs/.cache/
.qodo
db.sqlite3
#pr_agent/
21  Dockerfile  Normal file
@@ -0,0 +1,21 @@
FROM python:3.12-slim

ENV PYTHONUNBUFFERED=1
ENV TZ=Asia/Shanghai

WORKDIR /app

COPY . /app

RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources

RUN apt-get update \
    && apt-get install -y procps net-tools apt-utils \
    && ln -snf /usr/share/zoneinfo/${TZ} /etc/localtime && echo ${TZ} > /etc/timezone \
    && pip install pipenv -i https://pypi.tuna.tsinghua.edu.cn/simple/

RUN pipenv sync && pipenv install --dev

RUN chmod +x /app/start.sh

CMD ["sh", "start.sh"]
28  Pipfile  Normal file
@@ -0,0 +1,28 @@
[[source]]
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
verify_ssl = true
name = "pip_conf_index_global"

[packages]
django = "*"
simplepro = "*"
django-import-export = "*"
litellm = "*"
tenacity = "*"
html2text = "*"
starlette-context = "*"
dynaconf = "*"
loguru = "*"
atlassian-python-api = "*"
boto3 = "*"
gitpython = "*"
pygithub = "*"
python-gitlab = "*"
retry = "*"
fastapi = "*"

[dev-packages]

[requires]
python_version = "3.12"
python_full_version = "3.12.1"
2153  Pipfile.lock  generated  Normal file
File diff suppressed because it is too large
0  apps/__init__.py  Normal file
0  apps/pr/__init__.py  Normal file
40  apps/pr/admin.py  Normal file
@@ -0,0 +1,40 @@
from django.contrib import admin
from simpleui.admin import AjaxAdmin

from pr import models


@admin.register(models.AIConfig)
class AIConfigAdmin(AjaxAdmin):
    """Admin configuration"""

    list_display = ["api_base", "api_key", "llm_model"]
    top_html = ' <el-alert title="可配置多个AI模型厂商!" type="success"></el-alert>'

    def save_model(self, request, obj, form, change):
        obj.create_by = request.user.username
        return super().save_model(request, obj, form, change)


@admin.register(models.GitConfig)
class GitConfigAdmin(AjaxAdmin):
    """Admin configuration"""

    list_display = ["git_name", "git_type", "git_url", "access_token"]
    top_html = '<el-alert title="可配置多个Git服务商!" type="success"></el-alert>'

    def save_model(self, request, obj, form, change):
        obj.create_by = request.user.username
        return super().save_model(request, obj, form, change)


@admin.register(models.ProjectConfig)
class ProjectConfigAdmin(AjaxAdmin):
    """Admin configuration"""

    list_display = ["project_id", "project_name", "project_secret", "commands", "is_enable"]
    top_html = '<el-alert title="可配置多个项目!" type="success"></el-alert>'

    def save_model(self, request, obj, form, change):
        obj.create_by = request.user.username
        return super().save_model(request, obj, form, change)
7  apps/pr/apps.py  Normal file
@@ -0,0 +1,7 @@
from django.apps import AppConfig


class PrConfig(AppConfig):
    default_auto_field = "django.db.models.BigAutoField"
    name = "pr"
    verbose_name = "PR管理配置"
0  apps/pr/management/__init__.py  Normal file
0  apps/pr/management/commands/__init__.py  Normal file
19  apps/pr/management/commands/init_data.py  Normal file
@@ -0,0 +1,19 @@
from django.core.management.base import BaseCommand

from pr import models


class Command(BaseCommand):
    help = "数据初始化"

    def handle(self, *args, **options):
        ai_config, created = models.AIConfig.objects.get_or_create(
            api_base="http://110.40.24.85:3000/v1",
            api_key="sk-YLeQEboTsCEzfbmhbnytWRPyuC8Swe7OsBRKH30X26Jf1fsm",
            llm_model="o3-mini",
        )
        if created:
            print("初始化AI配置已创建")
        else:
            print("初始化AI配置已存在")
260  apps/pr/migrations/0001_initial.py  Normal file
@@ -0,0 +1,260 @@
# Generated by Django 5.1.6 on 2025-02-25 13:55

import django.db.models.deletion
import simplepro.components.fields
import uuid
from django.db import migrations, models


class Migration(migrations.Migration):
    initial = True

    dependencies = []

    operations = [
        migrations.CreateModel(
            name="AIConfig",
            fields=[
                ("id", models.BigAutoField(primary_key=True, serialize=False)),
                (
                    "uid",
                    models.UUIDField(
                        db_index=True,
                        default=uuid.uuid4,
                        editable=False,
                        verbose_name="UUID",
                    ),
                ),
                (
                    "create_at",
                    simplepro.components.fields.DateTimeField(
                        auto_now_add=True, db_index=True, verbose_name="创建时间"
                    ),
                ),
                (
                    "update_at",
                    simplepro.components.fields.DateTimeField(
                        auto_now=True, verbose_name="更新时间"
                    ),
                ),
                (
                    "delete_at",
                    simplepro.components.fields.DateTimeField(
                        blank=True, null=True, verbose_name="删除时间"
                    ),
                ),
                (
                    "create_by",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=32, null=True, verbose_name="创建人"
                    ),
                ),
                (
                    "detail",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=200, null=True, verbose_name="备注信息"
                    ),
                ),
                (
                    "api_base",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=128, null=True, verbose_name="API(代理)地址"
                    ),
                ),
                (
                    "api_key",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=128, null=True, verbose_name="API密钥"
                    ),
                ),
                (
                    "llm_model",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=16, null=True, verbose_name="LLM模型"
                    ),
                ),
            ],
            options={
                "verbose_name": "AI模型配置",
                "verbose_name_plural": "AI模型配置",
            },
        ),
        migrations.CreateModel(
            name="GitConfig",
            fields=[
                ("id", models.BigAutoField(primary_key=True, serialize=False)),
                (
                    "uid",
                    models.UUIDField(
                        db_index=True,
                        default=uuid.uuid4,
                        editable=False,
                        verbose_name="UUID",
                    ),
                ),
                (
                    "create_at",
                    simplepro.components.fields.DateTimeField(
                        auto_now_add=True, db_index=True, verbose_name="创建时间"
                    ),
                ),
                (
                    "update_at",
                    simplepro.components.fields.DateTimeField(
                        auto_now=True, verbose_name="更新时间"
                    ),
                ),
                (
                    "delete_at",
                    simplepro.components.fields.DateTimeField(
                        blank=True, null=True, verbose_name="删除时间"
                    ),
                ),
                (
                    "create_by",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=32, null=True, verbose_name="创建人"
                    ),
                ),
                (
                    "detail",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=200, null=True, verbose_name="备注信息"
                    ),
                ),
                (
                    "git_name",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=16, null=True, verbose_name="Git名称"
                    ),
                ),
                (
                    "git_type",
                    simplepro.components.fields.RadioField(
                        choices=[
                            ("gitlab", "gitlab"),
                            ("github", "github"),
                            ("gitea", "gitea"),
                        ],
                        default="gitlab",
                        verbose_name="Git类型",
                    ),
                ),
                (
                    "git_url",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=128, null=True, verbose_name="Git地址"
                    ),
                ),
                (
                    "access_token",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=128, null=True, verbose_name="访问密钥"
                    ),
                ),
                (
                    "pr_ai",
                    simplepro.components.fields.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        to="pr.aiconfig",
                        verbose_name="AI模型",
                    ),
                ),
            ],
            options={
                "verbose_name": "Git服务配置",
                "verbose_name_plural": "Git服务配置",
            },
        ),
        migrations.CreateModel(
            name="ProjectConfig",
            fields=[
                ("id", models.BigAutoField(primary_key=True, serialize=False)),
                (
                    "uid",
                    models.UUIDField(
                        db_index=True,
                        default=uuid.uuid4,
                        editable=False,
                        verbose_name="UUID",
                    ),
                ),
                (
                    "create_at",
                    simplepro.components.fields.DateTimeField(
                        auto_now_add=True, db_index=True, verbose_name="创建时间"
                    ),
                ),
                (
                    "update_at",
                    simplepro.components.fields.DateTimeField(
                        auto_now=True, verbose_name="更新时间"
                    ),
                ),
                (
                    "delete_at",
                    simplepro.components.fields.DateTimeField(
                        blank=True, null=True, verbose_name="删除时间"
                    ),
                ),
                (
                    "create_by",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=32, null=True, verbose_name="创建人"
                    ),
                ),
                (
                    "detail",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=200, null=True, verbose_name="备注信息"
                    ),
                ),
                (
                    "project_id",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=8, null=True, verbose_name="项目ID"
                    ),
                ),
                (
                    "project_name",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=16, null=True, verbose_name="项目名称"
                    ),
                ),
                (
                    "project_secret",
                    simplepro.components.fields.CharField(
                        blank=True, max_length=128, null=True, verbose_name="项目密钥"
                    ),
                ),
                (
                    "commands",
                    simplepro.components.fields.CheckboxField(
                        default=["/review"], max_length=256, verbose_name="默认命令"
                    ),
                ),
                (
                    "is_enable",
                    simplepro.components.fields.SwitchField(
                        default=True, verbose_name="是否启用"
                    ),
                ),
                (
                    "git_config",
                    simplepro.components.fields.ForeignKey(
                        blank=True,
                        null=True,
                        on_delete=django.db.models.deletion.SET_NULL,
                        to="pr.gitconfig",
                        verbose_name="Git配置",
                    ),
                ),
            ],
            options={
                "verbose_name": "项目配置",
                "verbose_name_plural": "项目配置",
            },
        ),
    ]
0  apps/pr/migrations/__init__.py  Normal file
94  apps/pr/models.py  Normal file
@@ -0,0 +1,94 @@
from django.db import models
from simplepro.components import fields

from public.models import BaseModel
from utils import constant


class AIConfig(BaseModel):
    """
    AI model configuration table
    """

    api_base = fields.CharField(
        null=True, blank=True, max_length=128, verbose_name="API(代理)地址"
    )
    api_key = fields.CharField(
        null=True, blank=True, max_length=128, verbose_name="API密钥"
    )
    llm_model = fields.CharField(
        null=True, blank=True, max_length=16, verbose_name="LLM模型"
    )

    class Meta:
        verbose_name = "AI模型配置"
        verbose_name_plural = "AI模型配置"


class GitConfig(BaseModel):
    """
    Git service configuration table
    """

    pr_ai = fields.ForeignKey(
        AIConfig,
        null=True,
        blank=True,
        on_delete=models.SET_NULL,
        verbose_name="AI模型",
    )
    git_name = fields.CharField(
        null=True, blank=True, max_length=16, verbose_name="Git名称"
    )
    git_type = fields.RadioField(
        choices=constant.GIT_TYPE,
        default="gitlab",
        verbose_name="Git类型"
    )
    git_url = fields.CharField(
        null=True, blank=True, max_length=128, verbose_name="Git地址"
    )
    access_token = fields.CharField(
        null=True, blank=True, max_length=128, verbose_name="访问密钥"
    )

    class Meta:
        verbose_name = "Git服务配置"
        verbose_name_plural = "Git服务配置"


class ProjectConfig(BaseModel):
    """
    Project configuration table
    """
    git_config = fields.ForeignKey(
        GitConfig,
        null=True,
        blank=True,
        on_delete=models.SET_NULL,
        verbose_name="Git配置",
    )
    project_id = fields.CharField(
        null=True, blank=True, max_length=8, verbose_name="项目ID"
    )
    project_name = fields.CharField(
        null=True, blank=True, max_length=16, verbose_name="项目名称"
    )
    project_secret = fields.CharField(
        null=True, blank=True, max_length=128, verbose_name="项目密钥"
    )
    commands = fields.CheckboxField(
        choices=constant.DEFAULT_COMMANDS,
        default=["/review"],
        max_length=256,
        verbose_name="默认命令",
    )
    is_enable = fields.SwitchField(
        default=True,
        verbose_name="是否启用"
    )

    class Meta:
        verbose_name = "项目配置"
        verbose_name_plural = "项目配置"
3  apps/pr/tests.py  Normal file
@@ -0,0 +1,3 @@
from django.test import TestCase

# Create your tests here.
24  apps/pr/urls.py  Normal file
@@ -0,0 +1,24 @@
"""
URL configuration for pr_manager project.

The `urlpatterns` list routes URLs to views. For more information please see:
    https://docs.djangoproject.com/en/5.1/topics/http/urls/
Examples:
Function views
    1. Add an import:  from my_app import views
    2. Add a URL to urlpatterns:  path('', views.home, name='home')
Class-based views
    1. Add an import:  from other_app.views import Home
    2. Add a URL to urlpatterns:  path('', Home.as_view(), name='home')
Including another URLconf
    1. Import the include() function: from django.urls import include, path
    2. Add a URL to urlpatterns:  path('blog/', include('blog.urls'))
"""

from django.urls import path

from pr.views import WebHookView

urlpatterns = [
    path("webhook/", WebHookView.as_view()),
]
113  apps/pr/views.py  Normal file
@@ -0,0 +1,113 @@
import json

from pr import models
from django.views import View
from django.http import JsonResponse

from utils.pr_agent import cli
from utils.pr_agent.config_loader import get_settings
from utils import constant


def load_project_config(
    git_url,
    access_token,
    project_secret,
    openai_api_base,
    openai_key,
    llm_model
):
    """
    Load the per-project configuration.
    :param git_url: Git server URL
    :param access_token: user access token
    :param project_secret: project webhook secret
    :param openai_api_base: OpenAI API base URL
    :param openai_key: OpenAI API key
    :param llm_model: LLM model name
    :return:
    """

    return {
        "gitlab_url": git_url,
        "access_token": access_token,
        "secret": project_secret,
        "openai_api_base": openai_api_base,
        "openai_key": openai_key,
        "llm_model": llm_model
    }


class WebHookView(View):
    def post(self, request):
        # GitLab delivers webhook payloads as a JSON body, so parse request.body
        try:
            data = json.loads(request.body)
        except (ValueError, TypeError):
            return JsonResponse(status=400, data={"error": "Invalid JSON"})
        if not data:
            return JsonResponse(status=400, data={"error": "Invalid JSON"})

        project_id = data.get('project', {}).get('id') or data.get('project_id')
        if not project_id:
            return JsonResponse(status=400, data={"error": "Missing project ID"})

        project_config = models.ProjectConfig.objects.filter(project_id=project_id).first()
        if not project_config:
            return JsonResponse(status=404, data={"error": "Project not configured"})
        # AI model configuration
        api_base = project_config.git_config.pr_ai.api_base
        api_key = project_config.git_config.pr_ai.api_key
        model = project_config.git_config.pr_ai.llm_model
        # Git server configuration
        git_url = project_config.git_config.git_url
        git_type = project_config.git_config.git_type
        access_token = project_config.git_config.access_token
        project_secret = project_config.project_secret
        project_commands = project_config.commands

        config = load_project_config(
            git_url=git_url,
            access_token=access_token,
            project_secret=project_secret,
            openai_api_base=api_base,
            openai_key=api_key,
            llm_model=model
        )
        token = request.headers.get('X-Gitlab-Token')
        if token:
            token = token.strip()
        expected_token = config["secret"].strip() if config["secret"] else None
        if token != expected_token:
            return JsonResponse(status=403, data={"error": "Invalid token"})

        # Handle merge request events
        if data.get('object_kind') == 'merge_request':
            merge_request = data.get('object_attributes', {})
            if merge_request.get('state') == 'opened':
                # Get the merge request details
                mr_url = merge_request.get('url')
                mr_action = merge_request.get('action')
                get_settings().set("config.git_provider", git_type)
                get_settings().set("gitlab.url", git_url)
                get_settings().set("gitlab.personal_access_token", access_token)
                get_settings().set("openai.api_base", api_base)
                get_settings().set("openai.key", api_key)
                get_settings().set("llm.model", model)

                if mr_action == "update":
                    old_rev = merge_request.get("oldrev")
                    new_rev = merge_request.get("newrev")
                    if old_rev == new_rev:
                        return JsonResponse(status=200, data={"status": "ignored (no code change)"})

                import threading

                def run_cmd(command):
                    cli.run_command(mr_url, command)

                threads = []
                for cmd in project_commands:
                    if cmd not in constant.DEFAULT_COMMANDS:
                        continue
                    t = threading.Thread(target=run_cmd, args=(cmd,))
                    threads.append(t)
                    t.start()
                # Log the MR info
                return JsonResponse(status=200, data={"status": "review started"})
            return JsonResponse(status=400, data={"error": "Merge request URL not found or action not open"})
        return JsonResponse(status=200, data={"status": "ignored"})
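For reference, a minimal sketch of how this webhook endpoint could be exercised with Django's test client, assuming pr.urls is mounted at the site root and that a matching ProjectConfig exists; the project id, secret and MR fields below are hypothetical placeholders, not values from this commit:

    # Hypothetical usage sketch (not part of the commit): posting a GitLab-style payload to WebHookView.
    from django.test import Client

    client = Client()
    payload = {
        "object_kind": "merge_request",
        "project": {"id": "123"},  # must match a ProjectConfig.project_id
        "object_attributes": {
            "state": "opened",
            "action": "open",
            "url": "https://gitlab.example.com/group/repo/-/merge_requests/1",
        },
    }
    response = client.post(
        "/webhook/",
        data=payload,
        content_type="application/json",
        headers={"X-Gitlab-Token": "project-secret-from-ProjectConfig"},  # hypothetical secret
    )
    print(response.status_code, response.json())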
0  apps/public/__init__.py  Normal file
3  apps/public/admin.py  Normal file
@@ -0,0 +1,3 @@
from django.contrib import admin

# Register your models here.
6  apps/public/apps.py  Normal file
@@ -0,0 +1,6 @@
from django.apps import AppConfig


class PublicConfig(AppConfig):
    default_auto_field = "django.db.models.BigAutoField"
    name = "public"
0  apps/public/migrations/__init__.py  Normal file
64  apps/public/models.py  Normal file
@@ -0,0 +1,64 @@
import uuid
from datetime import datetime

from django.db import models
from simplepro.components import fields


class BaseQuerySet(models.QuerySet):
    def set_delete(self):
        return super(BaseQuerySet, self).update(delete_at=datetime.now())


class BaseManager(models.Manager):
    def __init__(self, *args, **kwargs):
        self.alive_only = kwargs.pop("alive_only", True)
        super(BaseManager, self).__init__(*args, **kwargs)

    def get_queryset(self):
        """Soft delete"""
        if self.alive_only:
            # When alive_only is True, the queryset exposed to the admin excludes soft-deleted rows
            return BaseQuerySet(self.model).filter(delete_at__isnull=True)
        return BaseQuerySet(self.model)


class BaseModel(models.Model):
    """
    Custom model base class
    """

    id = models.BigAutoField(primary_key=True)
    uid = models.UUIDField(
        default=uuid.uuid4, editable=False, db_index=True, verbose_name="UUID"
    )
    create_at = fields.DateTimeField(
        auto_now_add=True, db_index=True, verbose_name="创建时间"
    )
    update_at = fields.DateTimeField(auto_now=True, verbose_name="更新时间")
    delete_at = fields.DateTimeField(null=True, blank=True, verbose_name="删除时间")
    create_by = fields.CharField(
        null=True, blank=True, max_length=32, verbose_name="创建人"
    )

    detail = fields.CharField(
        null=True,
        blank=True,
        max_length=200,
        show_word_limit=True,
        prefix_icon="el-icon-edit",
        verbose_name="备注信息",
        placeholder="请输入备注信息(可为空)",
    )

    objects = BaseManager()  # default manager: only rows that are not soft-deleted
    all_objects = BaseManager(alive_only=False)  # all rows, including soft-deleted ones

    class Meta:
        abstract = True
        ordering = ["-create_at"]

    def set_delete(self):
        """Soft delete"""
        self.delete_at = datetime.now()
        self.save()
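A brief illustration of how the soft-delete plumbing above behaves in practice — a minimal sketch, assuming the AIConfig model from apps/pr/models.py (which inherits BaseModel):

    # Illustration only (not part of the commit): soft delete via BaseModel/BaseManager.
    from pr.models import AIConfig

    cfg = AIConfig.objects.create(llm_model="o3-mini")
    cfg.set_delete()  # stamps delete_at instead of deleting the row

    AIConfig.objects.filter(pk=cfg.pk).exists()      # False - default manager hides soft-deleted rows
    AIConfig.all_objects.filter(pk=cfg.pk).exists()  # True  - all_objects still sees them

    AIConfig.objects.all().set_delete()  # BaseQuerySet.set_delete() soft-deletes in bulk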
3  apps/public/tests.py  Normal file
@@ -0,0 +1,3 @@
from django.test import TestCase

# Create your tests here.
3  apps/public/views.py  Normal file
@@ -0,0 +1,3 @@
from django.shortcuts import render

# Create your views here.
0  apps/utils/__init__.py  Normal file
11  apps/utils/constant.py  Normal file
@@ -0,0 +1,11 @@
GIT_TYPE = (
    ("gitlab", "gitlab"),
    ("github", "github"),
    ("gitea", "gitea")
)

DEFAULT_COMMANDS = [
    "/review",
    "/describe",
    "/improve_code"
]
0  apps/utils/pr_agent/__init__.py  Normal file
0  apps/utils/pr_agent/agent/__init__.py  Normal file
93  apps/utils/pr_agent/agent/pr_agent.py  Normal file
@@ -0,0 +1,93 @@
import shlex
from functools import partial

from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from utils.pr_agent.algo.cli_args import CliArgs
from utils.pr_agent.algo.utils import update_settings_from_args
from utils.pr_agent.git_providers.utils import apply_repo_settings
from utils.pr_agent.log import get_logger
from utils.pr_agent.tools.pr_add_docs import PRAddDocs
from utils.pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from utils.pr_agent.tools.pr_config import PRConfig
from utils.pr_agent.tools.pr_description import PRDescription
from utils.pr_agent.tools.pr_generate_labels import PRGenerateLabels
from utils.pr_agent.tools.pr_help_message import PRHelpMessage
from utils.pr_agent.tools.pr_line_questions import PR_LineQuestions
from utils.pr_agent.tools.pr_questions import PRQuestions
from utils.pr_agent.tools.pr_reviewer import PRReviewer
from utils.pr_agent.tools.pr_similar_issue import PRSimilarIssue
from utils.pr_agent.tools.pr_update_changelog import PRUpdateChangelog

command2class = {
    "auto_review": PRReviewer,
    "answer": PRReviewer,
    "review": PRReviewer,
    "review_pr": PRReviewer,
    "describe": PRDescription,
    "describe_pr": PRDescription,
    "improve": PRCodeSuggestions,
    "improve_code": PRCodeSuggestions,
    "ask": PRQuestions,
    "ask_question": PRQuestions,
    "ask_line": PR_LineQuestions,
    "update_changelog": PRUpdateChangelog,
    "config": PRConfig,
    "settings": PRConfig,
    "help": PRHelpMessage,
    "similar_issue": PRSimilarIssue,
    "add_docs": PRAddDocs,
    "generate_labels": PRGenerateLabels,
}

commands = list(command2class.keys())


class PRAgent:
    def __init__(self, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
        self.ai_handler = ai_handler  # will be initialized in run_action

    async def handle_request(self, pr_url, request, notify=None) -> bool:
        # First, apply repo specific settings if exists
        apply_repo_settings(pr_url)

        # Then, apply user specific settings if exists
        if isinstance(request, str):
            request = request.replace("'", "\\'")
            lexer = shlex.shlex(request, posix=True)
            lexer.whitespace_split = True
            action, *args = list(lexer)
        else:
            action, *args = request

        # validate args
        is_valid, arg = CliArgs.validate_user_args(args)
        if not is_valid:
            get_logger().error(
                f"CLI argument for param '{arg}' is forbidden. Use instead a configuration file."
            )
            return False

        # Update settings from args
        args = update_settings_from_args(args)

        action = action.lstrip("/").lower()
        if action not in command2class:
            get_logger().error(f"Unknown command: {action}")
            return False
        with get_logger().contextualize(command=action, pr_url=pr_url):
            get_logger().info("PR-Agent request handler started", analytics=True)
            if action == "answer":
                if notify:
                    notify()
                await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run()
            elif action == "auto_review":
                await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run()
            elif action in command2class:
                if notify:
                    notify()

                await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run()
            else:
                return False
            return True
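To illustrate the request parsing above: handle_request accepts either a list or a raw command string, which is split with shlex and validated by CliArgs before the matching tool class is awaited. A minimal sketch of the call shape, assuming the provider settings have already been configured; the MR URL and extra instruction are hypothetical:

    # Illustration only: how a command string maps onto PRAgent.handle_request.
    import asyncio

    from utils.pr_agent.agent.pr_agent import PRAgent

    agent = PRAgent()
    asyncio.run(
        agent.handle_request(
            "https://gitlab.example.com/group/repo/-/merge_requests/1",  # hypothetical MR URL
            "/review --pr_reviewer.extra_instructions='focus on error handling'",
        )
    )
    # -> "review" is looked up in command2class and PRReviewer(...).run() is awaited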
103  apps/utils/pr_agent/algo/__init__.py  Normal file
@@ -0,0 +1,103 @@
MAX_TOKENS = {
    'text-embedding-ada-002': 8000,
    'gpt-3.5-turbo': 16000,
    'gpt-3.5-turbo-0125': 16000,
    'gpt-3.5-turbo-0613': 4000,
    'gpt-3.5-turbo-1106': 16000,
    'gpt-3.5-turbo-16k': 16000,
    'gpt-3.5-turbo-16k-0613': 16000,
    'gpt-4': 8000,
    'gpt-4-0613': 8000,
    'gpt-4-32k': 32000,
    'gpt-4-1106-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4-0125-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4o': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4o-2024-05-13': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4-turbo-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4-turbo-2024-04-09': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4-turbo': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4o-mini': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4o-mini-2024-07-18': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4o-2024-08-06': 128000,  # 128K, but may be limited by config.max_model_tokens
    'gpt-4o-2024-11-20': 128000,  # 128K, but may be limited by config.max_model_tokens
    'o1-mini': 128000,  # 128K, but may be limited by config.max_model_tokens
    'o1-mini-2024-09-12': 128000,  # 128K, but may be limited by config.max_model_tokens
    'o1-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
    'o1-preview-2024-09-12': 128000,  # 128K, but may be limited by config.max_model_tokens
    'o1-2024-12-17': 204800,  # 200K, but may be limited by config.max_model_tokens
    'o1': 204800,  # 200K, but may be limited by config.max_model_tokens
    'o3-mini': 204800,  # 200K, but may be limited by config.max_model_tokens
    'o3-mini-2025-01-31': 204800,  # 200K, but may be limited by config.max_model_tokens
    'claude-instant-1': 100000,
    'claude-2': 100000,
    'command-nightly': 4096,
    'deepseek/deepseek-chat': 128000,  # 128K, but may be limited by config.max_model_tokens
    'deepseek/deepseek-reasoner': 64000,  # 64K, but may be limited by config.max_model_tokens
    'replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1': 4096,
    'meta-llama/Llama-2-7b-chat-hf': 4096,
    'vertex_ai/codechat-bison': 6144,
    'vertex_ai/codechat-bison-32k': 32000,
    'vertex_ai/claude-3-haiku@20240307': 100000,
    'vertex_ai/claude-3-5-haiku@20241022': 100000,
    'vertex_ai/claude-3-sonnet@20240229': 100000,
    'vertex_ai/claude-3-opus@20240229': 100000,
    'vertex_ai/claude-3-5-sonnet@20240620': 100000,
    'vertex_ai/claude-3-5-sonnet-v2@20241022': 100000,
    'vertex_ai/gemini-1.5-pro': 1048576,
    'vertex_ai/gemini-1.5-flash': 1048576,
    'vertex_ai/gemini-2.0-flash': 1048576,
    'vertex_ai/gemma2': 8200,
    'gemini/gemini-1.5-pro': 1048576,
    'gemini/gemini-1.5-flash': 1048576,
    'gemini/gemini-2.0-flash': 1048576,
    'codechat-bison': 6144,
    'codechat-bison-32k': 32000,
    'anthropic.claude-instant-v1': 100000,
    'anthropic.claude-v1': 100000,
    'anthropic.claude-v2': 100000,
    'anthropic/claude-3-opus-20240229': 100000,
    'anthropic/claude-3-5-sonnet-20240620': 100000,
    'anthropic/claude-3-5-sonnet-20241022': 100000,
    'anthropic/claude-3-5-haiku-20241022': 100000,
    'bedrock/anthropic.claude-instant-v1': 100000,
    'bedrock/anthropic.claude-v2': 100000,
    'bedrock/anthropic.claude-v2:1': 100000,
    'bedrock/anthropic.claude-3-sonnet-20240229-v1:0': 100000,
    'bedrock/anthropic.claude-3-haiku-20240307-v1:0': 100000,
    'bedrock/anthropic.claude-3-5-haiku-20241022-v1:0': 100000,
    'bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0': 100000,
    'bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0': 100000,
    "bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0": 100000,
    'claude-3-5-sonnet': 100000,
    'groq/llama3-8b-8192': 8192,
    'groq/llama3-70b-8192': 8192,
    'groq/llama-3.1-8b-instant': 8192,
    'groq/llama-3.3-70b-versatile': 128000,
    'groq/mixtral-8x7b-32768': 32768,
    'groq/gemma2-9b-it': 8192,
    'ollama/llama3': 4096,
    'watsonx/meta-llama/llama-3-8b-instruct': 4096,
    "watsonx/meta-llama/llama-3-70b-instruct": 4096,
    "watsonx/meta-llama/llama-3-405b-instruct": 16384,
    "watsonx/ibm/granite-13b-chat-v2": 8191,
    "watsonx/ibm/granite-34b-code-instruct": 8191,
    "watsonx/mistralai/mistral-large": 32768,
}

USER_MESSAGE_ONLY_MODELS = [
    "deepseek/deepseek-reasoner",
    "o1-mini",
    "o1-mini-2024-09-12",
    "o1-preview"
]

NO_SUPPORT_TEMPERATURE_MODELS = [
    "deepseek/deepseek-reasoner",
    "o1-mini",
    "o1-mini-2024-09-12",
    "o1",
    "o1-2024-12-17",
    "o3-mini",
    "o3-mini-2025-01-31",
    "o1-preview"
]
0  apps/utils/pr_agent/algo/ai_handlers/__init__.py  Normal file
28  apps/utils/pr_agent/algo/ai_handlers/base_ai_handler.py  Normal file
@@ -0,0 +1,28 @@
from abc import ABC, abstractmethod


class BaseAiHandler(ABC):
    """
    This class defines the interface for an AI handler to be used by the PR Agents.
    """

    @abstractmethod
    def __init__(self):
        pass

    @property
    @abstractmethod
    def deployment_id(self):
        pass

    @abstractmethod
    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
        """
        This method should be implemented to return a chat completion from the AI model.
        Args:
            model (str): the name of the model to use for the chat completion
            system (str): the system message string to use for the chat completion
            user (str): the user message string to use for the chat completion
            temperature (float): the temperature to use for the chat completion
        """
        pass
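As a quick illustration of the contract this interface defines, a minimal concrete handler might look like the following echo stub (useful in tests; it is hypothetical and not shipped in this commit):

    # Hypothetical stub implementing BaseAiHandler, e.g. for unit tests.
    from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler


    class EchoAiHandler(BaseAiHandler):
        def __init__(self):
            pass

        @property
        def deployment_id(self):
            return None

        async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
            # echo the prompt back; real handlers call an LLM and return (text, finish_reason)
            return f"[{model}] {user}", "stop"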
74  apps/utils/pr_agent/algo/ai_handlers/langchain_ai_handler.py  Normal file
@@ -0,0 +1,74 @@
try:
    from langchain_core.messages import HumanMessage, SystemMessage
    from langchain_openai import AzureChatOpenAI, ChatOpenAI
except:  # we don't enforce langchain as a dependency, so if it's not installed, just move on
    pass

from openai import APIError, RateLimitError, Timeout
from retry import retry

from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.log import get_logger

OPENAI_RETRIES = 5


class LangChainOpenAIHandler(BaseAiHandler):
    def __init__(self):
        # Initialize OpenAIHandler specific attributes here
        super().__init__()
        self.azure = get_settings().get("OPENAI.API_TYPE", "").lower() == "azure"

        # Create a default unused chat object to trigger early validation
        self._create_chat(self.deployment_id)

    def chat(self, messages: list, model: str, temperature: float):
        chat = self._create_chat(self.deployment_id)
        return chat.invoke(input=messages, model=model, temperature=temperature)

    @property
    def deployment_id(self):
        """
        Returns the deployment ID for the OpenAI API.
        """
        return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

    @retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError),
           tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
        try:
            messages = [SystemMessage(content=system), HumanMessage(content=user)]

            # get a chat completion from the formatted messages
            resp = self.chat(messages, model=model, temperature=temperature)
            finish_reason = "completed"
            return resp.content, finish_reason

        except (Exception) as e:
            get_logger().error("Unknown error during OpenAI inference: ", e)
            raise e

    def _create_chat(self, deployment_id=None):
        try:
            if self.azure:
                # using a partial function so we can set the deployment_id later to support fallback_deployments
                # but still need to access the other settings now so we can raise a proper exception if they're missing
                return AzureChatOpenAI(
                    openai_api_key=get_settings().openai.key,
                    openai_api_version=get_settings().openai.api_version,
                    azure_deployment=deployment_id,
                    azure_endpoint=get_settings().openai.api_base,
                )
            else:
                # for llms that compatible with openai, should use custom api base
                openai_api_base = get_settings().get("OPENAI.API_BASE", None)
                if openai_api_base is None or len(openai_api_base) == 0:
                    return ChatOpenAI(openai_api_key=get_settings().openai.key)
                else:
                    return ChatOpenAI(openai_api_key=get_settings().openai.key, openai_api_base=openai_api_base)
        except AttributeError as e:
            if getattr(e, "name"):
                raise ValueError(f"OpenAI {e.name} is required") from e
            else:
                raise e
277  apps/utils/pr_agent/algo/ai_handlers/litellm_ai_handler.py  Normal file
@@ -0,0 +1,277 @@
import os

import litellm
import openai
import requests
from litellm import acompletion
from tenacity import retry, retry_if_exception_type, stop_after_attempt

from utils.pr_agent.algo import NO_SUPPORT_TEMPERATURE_MODELS, USER_MESSAGE_ONLY_MODELS
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.utils import get_version
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.log import get_logger

OPENAI_RETRIES = 5


class LiteLLMAIHandler(BaseAiHandler):
    """
    This class handles interactions with the OpenAI API for chat completions.
    It initializes the API key and other settings from a configuration file,
    and provides a method for performing chat completions using the OpenAI ChatCompletion API.
    """

    def __init__(self):
        """
        Initializes the OpenAI API key and other settings from a configuration file.
        Raises a ValueError if the OpenAI key is missing.
        """
        self.azure = False
        self.api_base = None
        self.repetition_penalty = None
        if get_settings().get("OPENAI.KEY", None):
            openai.api_key = get_settings().openai.key
            litellm.openai_key = get_settings().openai.key
        elif 'OPENAI_API_KEY' not in os.environ:
            litellm.api_key = "dummy_key"
        if get_settings().get("aws.AWS_ACCESS_KEY_ID"):
            assert get_settings().aws.AWS_SECRET_ACCESS_KEY and get_settings().aws.AWS_REGION_NAME, "AWS credentials are incomplete"
            os.environ["AWS_ACCESS_KEY_ID"] = get_settings().aws.AWS_ACCESS_KEY_ID
            os.environ["AWS_SECRET_ACCESS_KEY"] = get_settings().aws.AWS_SECRET_ACCESS_KEY
            os.environ["AWS_REGION_NAME"] = get_settings().aws.AWS_REGION_NAME
        if get_settings().get("litellm.use_client"):
            litellm_token = get_settings().get("litellm.LITELLM_TOKEN")
            assert litellm_token, "LITELLM_TOKEN is required"
            os.environ["LITELLM_TOKEN"] = litellm_token
            litellm.use_client = True
        if get_settings().get("LITELLM.DROP_PARAMS", None):
            litellm.drop_params = get_settings().litellm.drop_params
        if get_settings().get("LITELLM.SUCCESS_CALLBACK", None):
            litellm.success_callback = get_settings().litellm.success_callback
        if get_settings().get("LITELLM.FAILURE_CALLBACK", None):
            litellm.failure_callback = get_settings().litellm.failure_callback
        if get_settings().get("LITELLM.SERVICE_CALLBACK", None):
            litellm.service_callback = get_settings().litellm.service_callback
        if get_settings().get("OPENAI.ORG", None):
            litellm.organization = get_settings().openai.org
        if get_settings().get("OPENAI.API_TYPE", None):
            if get_settings().openai.api_type == "azure":
                self.azure = True
                litellm.azure_key = get_settings().openai.key
        if get_settings().get("OPENAI.API_VERSION", None):
            litellm.api_version = get_settings().openai.api_version
        if get_settings().get("OPENAI.API_BASE", None):
            litellm.api_base = get_settings().openai.api_base
        if get_settings().get("ANTHROPIC.KEY", None):
            litellm.anthropic_key = get_settings().anthropic.key
        if get_settings().get("COHERE.KEY", None):
            litellm.cohere_key = get_settings().cohere.key
        if get_settings().get("GROQ.KEY", None):
            litellm.api_key = get_settings().groq.key
        if get_settings().get("REPLICATE.KEY", None):
            litellm.replicate_key = get_settings().replicate.key
        if get_settings().get("HUGGINGFACE.KEY", None):
            litellm.huggingface_key = get_settings().huggingface.key
        if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
            litellm.api_base = get_settings().huggingface.api_base
            self.api_base = get_settings().huggingface.api_base
        if get_settings().get("OLLAMA.API_BASE", None):
            litellm.api_base = get_settings().ollama.api_base
            self.api_base = get_settings().ollama.api_base
        if get_settings().get("HUGGINGFACE.REPETITION_PENALTY", None):
            self.repetition_penalty = float(get_settings().huggingface.repetition_penalty)
        if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
            litellm.vertex_project = get_settings().vertexai.vertex_project
            litellm.vertex_location = get_settings().get(
                "VERTEXAI.VERTEX_LOCATION", None
            )
        # Google AI Studio
        # SEE https://docs.litellm.ai/docs/providers/gemini
        if get_settings().get("GOOGLE_AI_STUDIO.GEMINI_API_KEY", None):
            os.environ["GEMINI_API_KEY"] = get_settings().google_ai_studio.gemini_api_key

        # Support deepseek models
        if get_settings().get("DEEPSEEK.KEY", None):
            os.environ['DEEPSEEK_API_KEY'] = get_settings().get("DEEPSEEK.KEY")

        # Models that only use a user message
        self.user_message_only_models = USER_MESSAGE_ONLY_MODELS

        # Models that don't support the temperature argument
        self.no_support_temperature_models = NO_SUPPORT_TEMPERATURE_MODELS

    def prepare_logs(self, response, system, user, resp, finish_reason):
        response_log = response.dict().copy()
        response_log['system'] = system
        response_log['user'] = user
        response_log['output'] = resp
        response_log['finish_reason'] = finish_reason
        if hasattr(self, 'main_pr_language'):
            response_log['main_pr_language'] = self.main_pr_language
        else:
            response_log['main_pr_language'] = 'unknown'
        return response_log

    def add_litellm_callbacks(selfs, kwargs) -> dict:
        captured_extra = []

        def capture_logs(message):
            # Parsing the log message and context
            record = message.record
            log_entry = {}
            if record.get('extra', None).get('command', None) is not None:
                log_entry.update({"command": record['extra']["command"]})
            if record.get('extra', {}).get('pr_url', None) is not None:
                log_entry.update({"pr_url": record['extra']["pr_url"]})

            # Append the log entry to the captured_logs list
            captured_extra.append(log_entry)

        # Adding the custom sink to Loguru
        handler_id = get_logger().add(capture_logs)
        get_logger().debug("Capturing logs for litellm callbacks")
        get_logger().remove(handler_id)

        context = captured_extra[0] if len(captured_extra) > 0 else None

        command = context.get("command", "unknown")
        pr_url = context.get("pr_url", "unknown")
        git_provider = get_settings().config.git_provider

        metadata = dict()
        callbacks = litellm.success_callback + litellm.failure_callback + litellm.service_callback
        if "langfuse" in callbacks:
            metadata.update({
                "trace_name": command,
                "tags": [git_provider, command, f'version:{get_version()}'],
                "trace_metadata": {
                    "command": command,
                    "pr_url": pr_url,
                },
            })
        if "langsmith" in callbacks:
            metadata.update({
                "run_name": command,
                "tags": [git_provider, command, f'version:{get_version()}'],
                "extra": {
                    "metadata": {
                        "command": command,
                        "pr_url": pr_url,
                    }
                },
            })

        # Adding the captured logs to the kwargs
        kwargs["metadata"] = metadata

        return kwargs

    @property
    def deployment_id(self):
        """
        Returns the deployment ID for the OpenAI API.
        """
        return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

    @retry(
        retry=retry_if_exception_type((openai.APIError, openai.APIConnectionError, openai.APITimeoutError)),  # No retry on RateLimitError
        stop=stop_after_attempt(OPENAI_RETRIES)
    )
    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
        try:
            resp, finish_reason = None, None
            deployment_id = self.deployment_id
            if self.azure:
                model = 'azure/' + model
            if 'claude' in model and not system:
                system = "No system prompt provided"
                get_logger().warning(
                    "Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error.")
            messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]

            if img_path:
                try:
                    # check if the image link is alive
                    r = requests.head(img_path, allow_redirects=True)
                    if r.status_code == 404:
                        error_msg = f"The image link is not [alive](img_path).\nPlease repost the original image as a comment, and send the question again with 'quote reply' (see [instructions](https://pr-agent-docs.codium.ai/tools/ask/#ask-on-images-using-the-pr-code-as-context))."
                        get_logger().error(error_msg)
                        return f"{error_msg}", "error"
                except Exception as e:
                    get_logger().error(f"Error fetching image: {img_path}", e)
                    return f"Error fetching image: {img_path}", "error"
                messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]},
                                          {"type": "image_url", "image_url": {"url": img_path}}]

            # Currently, some models do not support a separate system and user prompts
            if model in self.user_message_only_models or get_settings().config.custom_reasoning_model:
                user = f"{system}\n\n\n{user}"
                system = ""
                get_logger().info(f"Using model {model}, combining system and user prompts")
                messages = [{"role": "user", "content": user}]
                kwargs = {
                    "model": model,
                    "deployment_id": deployment_id,
                    "messages": messages,
                    "timeout": get_settings().config.ai_timeout,
                    "api_base": self.api_base,
                }
            else:
                kwargs = {
                    "model": model,
                    "deployment_id": deployment_id,
                    "messages": messages,
                    "timeout": get_settings().config.ai_timeout,
                    "api_base": self.api_base,
                }

            # Add temperature only if model supports it
            if model not in self.no_support_temperature_models and not get_settings().config.custom_reasoning_model:
                kwargs["temperature"] = temperature

            if get_settings().litellm.get("enable_callbacks", False):
                kwargs = self.add_litellm_callbacks(kwargs)

            seed = get_settings().config.get("seed", -1)
            if temperature > 0 and seed >= 0:
                raise ValueError(f"Seed ({seed}) is not supported with temperature ({temperature}) > 0")
            elif seed >= 0:
                get_logger().info(f"Using fixed seed of {seed}")
                kwargs["seed"] = seed

            if self.repetition_penalty:
                kwargs["repetition_penalty"] = self.repetition_penalty

            get_logger().debug("Prompts", artifact={"system": system, "user": user})

            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"\nSystem prompt:\n{system}")
                get_logger().info(f"\nUser prompt:\n{user}")

            response = await acompletion(**kwargs)
        except (openai.APIError, openai.APITimeoutError) as e:
            get_logger().warning(f"Error during LLM inference: {e}")
            raise
        except (openai.RateLimitError) as e:
            get_logger().error(f"Rate limit error during LLM inference: {e}")
            raise
        except (Exception) as e:
            get_logger().warning(f"Unknown error during LLM inference: {e}")
            raise openai.APIError from e
        if response is None or len(response["choices"]) == 0:
            raise openai.APIError
        else:
            resp = response["choices"][0]['message']['content']
            finish_reason = response["choices"][0]["finish_reason"]
            get_logger().debug(f"\nAI response:\n{resp}")

            # log the full response for debugging
            response_log = self.prepare_logs(response, system, user, resp, finish_reason)
            get_logger().debug("Full_response", artifact=response_log)

            # for CLI debugging
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"\nAI response:\n{resp}")

        return resp, finish_reason
67  apps/utils/pr_agent/algo/ai_handlers/openai_ai_handler.py  Normal file
@@ -0,0 +1,67 @@
from os import environ
import openai
from openai import APIError, AsyncOpenAI, RateLimitError, Timeout
from retry import retry

from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.log import get_logger

OPENAI_RETRIES = 5


class OpenAIHandler(BaseAiHandler):
    def __init__(self):
        # Initialize OpenAIHandler specific attributes here
        try:
            super().__init__()
            environ["OPENAI_API_KEY"] = get_settings().openai.key
            if get_settings().get("OPENAI.ORG", None):
                openai.organization = get_settings().openai.org
            if get_settings().get("OPENAI.API_TYPE", None):
                if get_settings().openai.api_type == "azure":
                    self.azure = True
                    openai.azure_key = get_settings().openai.key
            if get_settings().get("OPENAI.API_VERSION", None):
                openai.api_version = get_settings().openai.api_version
            if get_settings().get("OPENAI.API_BASE", None):
                environ["OPENAI_BASE_URL"] = get_settings().openai.api_base

        except AttributeError as e:
            raise ValueError("OpenAI key is required") from e

    @property
    def deployment_id(self):
        """
        Returns the deployment ID for the OpenAI API.
        """
        return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

    @retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError),
           tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
        try:
            get_logger().info("System: ", system)
            get_logger().info("User: ", user)
            messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
            client = AsyncOpenAI()
            chat_completion = await client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
            )
            resp = chat_completion.choices[0].message.content
            finish_reason = chat_completion.choices[0].finish_reason
            usage = chat_completion.usage
            get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
                              model=model, usage=usage)
            return resp, finish_reason
        except (APIError, Timeout) as e:
            get_logger().error("Error during OpenAI inference: ", e)
            raise
        except (RateLimitError) as e:
            get_logger().error("Rate limit error during OpenAI inference: ", e)
            raise
        except (Exception) as e:
            get_logger().error("Unknown error during OpenAI inference: ", e)
            raise
34  apps/utils/pr_agent/algo/cli_args.py  Normal file
@@ -0,0 +1,34 @@
from base64 import b64decode
import hashlib

class CliArgs:
    @staticmethod
    def validate_user_args(args: list) -> (bool, str):
        try:
            if not args:
                return True, ""

            # decode forbidden args
            _encoded_args = 'ZW5hYmxlX2F1dG9fYXBwcm92YWw=:YXBwcm92ZV9wcl9vbl9zZWxmX3Jldmlldw==:YmFzZV91cmw=:dXJs:YXBwX25hbWU=:c2VjcmV0X3Byb3ZpZGVy:Z2l0X3Byb3ZpZGVy:c2tpcF9rZXlz:b3BlbmFpLmtleQ==:QU5BTFlUSUNTX0ZPTERFUg==:dXJp:YXBwX2lk:d2ViaG9va19zZWNyZXQ=:YmVhcmVyX3Rva2Vu:UEVSU09OQUxfQUNDRVNTX1RPS0VO:b3ZlcnJpZGVfZGVwbG95bWVudF90eXBl:cHJpdmF0ZV9rZXk=:bG9jYWxfY2FjaGVfcGF0aA==:ZW5hYmxlX2xvY2FsX2NhY2hl:amlyYV9iYXNlX3VybA==:YXBpX2Jhc2U=:YXBpX3R5cGU=:YXBpX3ZlcnNpb24=:c2tpcF9rZXlz'
            forbidden_cli_args = []
            for e in _encoded_args.split(':'):
                forbidden_cli_args.append(b64decode(e).decode())

            # lowercase all forbidden args
            for i, _ in enumerate(forbidden_cli_args):
                forbidden_cli_args[i] = forbidden_cli_args[i].lower()
                if '.' not in forbidden_cli_args[i]:
                    forbidden_cli_args[i] = '.' + forbidden_cli_args[i]

            for arg in args:
                if arg.startswith('--'):
                    arg_word = arg.lower()
                    arg_word = arg_word.replace('__', '.')  # replace double underscore with dot, e.g. --openai__key -> --openai.key
                    for forbidden_arg_word in forbidden_cli_args:
                        if forbidden_arg_word in arg_word:
                            return False, forbidden_arg_word
            return True, ""
        except Exception as e:
            return False, str(e)
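A short illustration of the validator above: arguments are normalized (lower-cased, `__` mapped to `.`) and rejected if they touch one of the base64-encoded protected keys; everything else passes through. A minimal sketch of the expected behaviour:

    # Illustration only: expected behaviour of CliArgs.validate_user_args.
    from utils.pr_agent.algo.cli_args import CliArgs

    print(CliArgs.validate_user_args([]))                                      # (True, "")
    print(CliArgs.validate_user_args(["--pr_reviewer.extra_instructions=be terse"]))  # (True, "") - not a protected key
    print(CliArgs.validate_user_args(["--openai__key=sk-xxx"]))                # (False, "openai.key") - forbidden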
65  apps/utils/pr_agent/algo/file_filter.py  Normal file
@@ -0,0 +1,65 @@
import fnmatch
import re

from utils.pr_agent.config_loader import get_settings


def filter_ignored(files, platform = 'github'):
    """
    Filter out files that match the ignore patterns.
    """

    try:
        # load regex patterns, and translate glob patterns to regex
        patterns = get_settings().ignore.regex
        if isinstance(patterns, str):
            patterns = [patterns]
        glob_setting = get_settings().ignore.glob
        if isinstance(glob_setting, str):  # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
            glob_setting = glob_setting.strip('[]').split(",")
        patterns += [fnmatch.translate(glob) for glob in glob_setting]

        # compile all valid patterns
        compiled_patterns = []
        for r in patterns:
            try:
                compiled_patterns.append(re.compile(r))
            except re.error:
                pass

        # keep filenames that _don't_ match the ignore regex
        if files and isinstance(files, list):
            for r in compiled_patterns:
                if platform == 'github':
                    files = [f for f in files if (f.filename and not r.match(f.filename))]
                elif platform == 'bitbucket':
                    # files = [f for f in files if (f.new.path and not r.match(f.new.path))]
                    files_o = []
                    for f in files:
                        if hasattr(f, 'new'):
                            if f.new and f.new.path and not r.match(f.new.path):
                                files_o.append(f)
                                continue
                        if hasattr(f, 'old'):
                            if f.old and f.old.path and not r.match(f.old.path):
                                files_o.append(f)
                                continue
                    files = files_o
                elif platform == 'gitlab':
                    # files = [f for f in files if (f['new_path'] and not r.match(f['new_path']))]
                    files_o = []
                    for f in files:
                        if 'new_path' in f and f['new_path'] and not r.match(f['new_path']):
                            files_o.append(f)
                            continue
                        if 'old_path' in f and f['old_path'] and not r.match(f['old_path']):
                            files_o.append(f)
                            continue
                    files = files_o
                elif platform == 'azure':
                    files = [f for f in files if not r.match(f)]

    except Exception as e:
        print(f"Could not filter file list: {e}")

    return files
414
apps/utils/pr_agent/algo/git_patch_processing.py
Normal file
414
apps/utils/pr_agent/algo/git_patch_processing.py
Normal file
@ -0,0 +1,414 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import traceback
|
||||
|
||||
from utils.pr_agent.algo.types import EDIT_TYPE
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
from utils.pr_agent.log import get_logger
|
||||
|
||||
|
||||
def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
|
||||
patch_extra_lines_after=0, filename: str = "") -> str:
|
||||
if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0) or not original_file_str:
|
||||
return patch_str
|
||||
|
||||
original_file_str = decode_if_bytes(original_file_str)
|
||||
if not original_file_str:
|
||||
return patch_str
|
||||
|
||||
if should_skip_patch(filename):
|
||||
return patch_str
|
||||
|
||||
try:
|
||||
extended_patch_str = process_patch_lines(patch_str, original_file_str,
|
||||
patch_extra_lines_before, patch_extra_lines_after)
|
||||
except Exception as e:
|
||||
get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
|
||||
return patch_str
|
||||
|
||||
return extended_patch_str
|
||||
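An illustrative sketch (not part of the commit) of what extend_patch does: each hunk is widened with extra context lines pulled from the original file and the @@ header is rewritten. The file content and patch are made up, and allow_dynamic_context is disabled here so the arithmetic is deterministic.

from utils.pr_agent.algo.git_patch_processing import extend_patch
from utils.pr_agent.config_loader import get_settings

get_settings().set("config.allow_dynamic_context", False)
original = "\n".join(f"line {i}" for i in range(1, 21))          # 20-line original file
patch = "@@ -10,2 +10,3 @@\n line 10\n+line 10b\n line 11"
wider = extend_patch(original, patch, patch_extra_lines_before=2, patch_extra_lines_after=2)
# the hunk header becomes '@@ -8,6 +8,7 @@' and 'line 8', 'line 9',
# 'line 12' and 'line 13' are added as plain context lines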
|
||||
|
||||
def decode_if_bytes(original_file_str):
|
||||
if isinstance(original_file_str, (bytes, bytearray)):
|
||||
try:
|
||||
return original_file_str.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
encodings_to_try = ['iso-8859-1', 'latin-1', 'ascii', 'utf-16']
|
||||
for encoding in encodings_to_try:
|
||||
try:
|
||||
return original_file_str.decode(encoding)
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
return ""
|
||||
return original_file_str
|
||||
|
||||
|
||||
def should_skip_patch(filename):
|
||||
patch_extension_skip_types = get_settings().config.patch_extension_skip_types
|
||||
if patch_extension_skip_types and filename:
|
||||
return any(filename.endswith(skip_type) for skip_type in patch_extension_skip_types)
|
||||
return False
|
||||
|
||||
|
||||
def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after):
|
||||
allow_dynamic_context = get_settings().config.allow_dynamic_context
|
||||
patch_extra_lines_before_dynamic = get_settings().config.max_extra_lines_before_dynamic_context
|
||||
|
||||
original_lines = original_file_str.splitlines()
|
||||
len_original_lines = len(original_lines)
|
||||
patch_lines = patch_str.splitlines()
|
||||
extended_patch_lines = []
|
||||
|
||||
is_valid_hunk = True
|
||||
start1, size1, start2, size2 = -1, -1, -1, -1
|
||||
RE_HUNK_HEADER = re.compile(
|
||||
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
|
||||
try:
|
||||
for i,line in enumerate(patch_lines):
|
||||
if line.startswith('@@'):
|
||||
match = RE_HUNK_HEADER.match(line)
|
||||
# identify hunk header
|
||||
if match:
|
||||
# finish processing previous hunk
|
||||
if is_valid_hunk and (start1 != -1 and patch_extra_lines_after > 0):
|
||||
delta_lines = [f' {line}' for line in original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]]
|
||||
extended_patch_lines.extend(delta_lines)
|
||||
|
||||
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
|
||||
|
||||
is_valid_hunk = check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1)
|
||||
|
||||
if is_valid_hunk and (patch_extra_lines_before > 0 or patch_extra_lines_after > 0):
|
||||
def _calc_context_limits(patch_lines_before):
|
||||
extended_start1 = max(1, start1 - patch_lines_before)
|
||||
extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after
|
||||
extended_start2 = max(1, start2 - patch_lines_before)
|
||||
extended_size2 = size2 + (start2 - extended_start2) + patch_extra_lines_after
|
||||
if extended_start1 - 1 + extended_size1 > len_original_lines:
|
||||
# we cannot extend beyond the original file
|
||||
delta_cap = extended_start1 - 1 + extended_size1 - len_original_lines
|
||||
extended_size1 = max(extended_size1 - delta_cap, size1)
|
||||
extended_size2 = max(extended_size2 - delta_cap, size2)
|
||||
return extended_start1, extended_size1, extended_start2, extended_size2
|
||||
|
||||
if allow_dynamic_context:
|
||||
extended_start1, extended_size1, extended_start2, extended_size2 = \
|
||||
_calc_context_limits(patch_extra_lines_before_dynamic)
|
||||
lines_before = original_lines[extended_start1 - 1:start1 - 1]
|
||||
found_header = False
|
||||
for i, line, in enumerate(lines_before):
|
||||
if section_header in line:
|
||||
found_header = True
|
||||
# Update start and size in one line each
|
||||
extended_start1, extended_start2 = extended_start1 + i, extended_start2 + i
|
||||
extended_size1, extended_size2 = extended_size1 - i, extended_size2 - i
|
||||
# get_logger().debug(f"Found section header in line {i} before the hunk")
|
||||
section_header = ''
|
||||
break
|
||||
if not found_header:
|
||||
# get_logger().debug(f"Section header not found in the extra lines before the hunk")
|
||||
extended_start1, extended_size1, extended_start2, extended_size2 = \
|
||||
_calc_context_limits(patch_extra_lines_before)
|
||||
else:
|
||||
extended_start1, extended_size1, extended_start2, extended_size2 = \
|
||||
_calc_context_limits(patch_extra_lines_before)
|
||||
|
||||
delta_lines = [f' {line}' for line in original_lines[extended_start1 - 1:start1 - 1]]
|
||||
|
||||
# logic to remove section header if its in the extra delta lines (in dynamic context, this is also done)
|
||||
if section_header and not allow_dynamic_context:
|
||||
for line in delta_lines:
|
||||
if section_header in line:
|
||||
section_header = '' # remove section header if it is in the extra delta lines
|
||||
break
|
||||
else:
|
||||
extended_start1 = start1
|
||||
extended_size1 = size1
|
||||
extended_start2 = start2
|
||||
extended_size2 = size2
|
||||
delta_lines = []
|
||||
extended_patch_lines.append('')
|
||||
extended_patch_lines.append(
|
||||
f'@@ -{extended_start1},{extended_size1} '
|
||||
f'+{extended_start2},{extended_size2} @@ {section_header}')
|
||||
extended_patch_lines.extend(delta_lines) # one to zero based
|
||||
continue
|
||||
extended_patch_lines.append(line)
|
||||
except Exception as e:
|
||||
get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
|
||||
return patch_str
|
||||
|
||||
# finish processing last hunk
|
||||
if start1 != -1 and patch_extra_lines_after > 0 and is_valid_hunk:
|
||||
delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
|
||||
# add space at the beginning of each extra line
|
||||
delta_lines = [f' {line}' for line in delta_lines]
|
||||
extended_patch_lines.extend(delta_lines)
|
||||
|
||||
extended_patch_str = '\n'.join(extended_patch_lines)
|
||||
return extended_patch_str
|
||||
|
||||
|
||||
def check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1):
|
||||
"""
|
||||
Check if the hunk lines match the original file content. We saw cases where the hunk header line doesn't match the original file content, and then
|
||||
extending the hunk with extra lines before the hunk header can cause the hunk to be invalid.
|
||||
"""
|
||||
is_valid_hunk = True
|
||||
try:
|
||||
if i + 1 < len(patch_lines) and patch_lines[i + 1][0] == ' ': # an existing line in the file
|
||||
if patch_lines[i + 1].strip() != original_lines[start1 - 1].strip():
|
||||
is_valid_hunk = False
|
||||
get_logger().error(
|
||||
f"Invalid hunk in PR, line {start1} in hunk header doesn't match the original file content")
|
||||
except:
|
||||
pass
|
||||
return is_valid_hunk
|
||||
|
||||
|
||||
def extract_hunk_headers(match):
|
||||
res = list(match.groups())
|
||||
for i in range(len(res)):
|
||||
if res[i] is None:
|
||||
res[i] = 0
|
||||
try:
|
||||
start1, size1, start2, size2 = map(int, res[:4])
|
||||
except: # '@@ -0,0 +1 @@' case
|
||||
start1, size1, size2 = map(int, res[:3])
|
||||
start2 = 0
|
||||
section_header = res[4]
|
||||
return section_header, size1, size2, start1, start2
|
||||
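A quick illustration (not part of the commit) of the hunk-header regex used throughout this module: it captures the old start/length, the new start/length, and the optional section-header text after the closing @@.

import re

RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
m = RE_HUNK_HEADER.match("@@ -132,7 +132,9 @@ def filter_ignored(files):")
assert m.groups() == ('132', '7', '132', '9', 'def filter_ignored(files):')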
|
||||
|
||||
def omit_deletion_hunks(patch_lines) -> str:
|
||||
"""
|
||||
Omit deletion hunks from the patch and return the modified patch.
|
||||
Args:
|
||||
- patch_lines: a list of strings representing the lines of the patch
|
||||
Returns:
|
||||
- A string representing the modified patch with deletion hunks omitted
|
||||
"""
|
||||
|
||||
temp_hunk = []
|
||||
added_patched = []
|
||||
add_hunk = False
|
||||
inside_hunk = False
|
||||
RE_HUNK_HEADER = re.compile(
|
||||
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))?\ @@[ ]?(.*)")
|
||||
|
||||
for line in patch_lines:
|
||||
if line.startswith('@@'):
|
||||
match = RE_HUNK_HEADER.match(line)
|
||||
if match:
|
||||
# finish previous hunk
|
||||
if inside_hunk and add_hunk:
|
||||
added_patched.extend(temp_hunk)
|
||||
temp_hunk = []
|
||||
add_hunk = False
|
||||
temp_hunk.append(line)
|
||||
inside_hunk = True
|
||||
else:
|
||||
temp_hunk.append(line)
|
||||
if line:
|
||||
edit_type = line[0]
|
||||
if edit_type == '+':
|
||||
add_hunk = True
|
||||
if inside_hunk and add_hunk:
|
||||
added_patched.extend(temp_hunk)
|
||||
|
||||
return '\n'.join(added_patched)
|
||||
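A hedged example (not part of the commit): hunks that contain at least one added line are kept verbatim, while deletion-only hunks are dropped. The patch lines below are invented.

from utils.pr_agent.algo.git_patch_processing import omit_deletion_hunks

patch_lines = [
    "@@ -1,1 +1,2 @@",    # hunk that adds a line -> kept
    " context",
    "+new line",
    "@@ -10,2 +11,1 @@",  # deletion-only hunk -> omitted
    " keep context",
    "-removed line",
]
result = omit_deletion_hunks(patch_lines)
# result contains only the first hunk: its header, ' context' and '+new line'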
|
||||
|
||||
def handle_patch_deletions(patch: str, original_file_content_str: str,
|
||||
new_file_content_str: str, file_name: str, edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN) -> str:
|
||||
"""
|
||||
Handle entire file or deletion patches.
|
||||
|
||||
This function takes a patch, original file content, new file content, and file name as input.
|
||||
It handles entire file or deletion patches and returns the modified patch with deletion hunks omitted.
|
||||
|
||||
Args:
|
||||
patch (str): The patch to be handled.
|
||||
original_file_content_str (str): The original content of the file.
|
||||
new_file_content_str (str): The new content of the file.
|
||||
file_name (str): The name of the file.
|
||||
|
||||
Returns:
|
||||
str: The modified patch with deletion hunks omitted.
|
||||
|
||||
"""
|
||||
if not new_file_content_str and (edit_type == EDIT_TYPE.DELETED or edit_type == EDIT_TYPE.UNKNOWN):
|
||||
# logic for handling deleted files - don't show patch, just show that the file was deleted
|
||||
if get_settings().config.verbosity_level > 0:
|
||||
get_logger().info(f"Processing file: {file_name}, minimizing deletion file")
|
||||
patch = None # file was deleted
|
||||
else:
|
||||
patch_lines = patch.splitlines()
|
||||
patch_new = omit_deletion_hunks(patch_lines)
|
||||
if patch != patch_new:
|
||||
if get_settings().config.verbosity_level > 0:
|
||||
get_logger().info(f"Processing file: {file_name}, hunks were deleted")
|
||||
patch = patch_new
|
||||
return patch
|
||||
|
||||
|
||||
def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
|
||||
"""
|
||||
Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of
|
||||
the file.
|
||||
|
||||
Args:
|
||||
patch (str): The patch string to be converted.
|
||||
file: An object containing the filename of the file being patched.
|
||||
|
||||
Returns:
|
||||
str: A string with line numbers for each hunk, indicating the new and old content of the file.
|
||||
|
||||
example output:
|
||||
## src/file.ts
|
||||
__new hunk__
|
||||
881 line1
|
||||
882 line2
|
||||
883 line3
|
||||
887 + line4
|
||||
888 + line5
|
||||
889 line6
|
||||
890 line7
|
||||
...
|
||||
__old hunk__
|
||||
line1
|
||||
line2
|
||||
- line3
|
||||
- line4
|
||||
line5
|
||||
line6
|
||||
...
|
||||
"""
|
||||
# if the file was deleted, return a message indicating that the file was deleted
|
||||
if hasattr(file, 'edit_type') and file.edit_type == EDIT_TYPE.DELETED:
|
||||
return f"\n\n## file '{file.filename.strip()}' was deleted\n"
|
||||
|
||||
patch_with_lines_str = f"\n\n## File: '{file.filename.strip()}'\n"
|
||||
patch_lines = patch.splitlines()
|
||||
RE_HUNK_HEADER = re.compile(
|
||||
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
|
||||
new_content_lines = []
|
||||
old_content_lines = []
|
||||
match = None
|
||||
start1, size1, start2, size2 = -1, -1, -1, -1
|
||||
prev_header_line = []
|
||||
header_line = []
|
||||
for line_i, line in enumerate(patch_lines):
|
||||
if 'no newline at end of file' in line.lower():
|
||||
continue
|
||||
|
||||
if line.startswith('@@'):
|
||||
header_line = line
|
||||
match = RE_HUNK_HEADER.match(line)
|
||||
if match and (new_content_lines or old_content_lines): # found a new hunk, split the previous lines
|
||||
if prev_header_line:
|
||||
patch_with_lines_str += f'\n{prev_header_line}\n'
|
||||
is_plus_lines = is_minus_lines = False
|
||||
if new_content_lines:
|
||||
is_plus_lines = any([line.startswith('+') for line in new_content_lines])
|
||||
if old_content_lines:
|
||||
is_minus_lines = any([line.startswith('-') for line in old_content_lines])
|
||||
if is_plus_lines or is_minus_lines: # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
|
||||
patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__new hunk__\n'
|
||||
for i, line_new in enumerate(new_content_lines):
|
||||
patch_with_lines_str += f"{start2 + i} {line_new}\n"
|
||||
if is_minus_lines:
|
||||
patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__old hunk__\n'
|
||||
for line_old in old_content_lines:
|
||||
patch_with_lines_str += f"{line_old}\n"
|
||||
new_content_lines = []
|
||||
old_content_lines = []
|
||||
if match:
|
||||
prev_header_line = header_line
|
||||
|
||||
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
|
||||
|
||||
elif line.startswith('+'):
|
||||
new_content_lines.append(line)
|
||||
elif line.startswith('-'):
|
||||
old_content_lines.append(line)
|
||||
else:
|
||||
if not line and line_i: # if this line is empty and the next line is a hunk header, skip it
|
||||
if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith('@@'):
|
||||
continue
|
||||
elif line_i + 1 == len(patch_lines):
|
||||
continue
|
||||
new_content_lines.append(line)
|
||||
old_content_lines.append(line)
|
||||
|
||||
# finishing last hunk
|
||||
if match and new_content_lines:
|
||||
patch_with_lines_str += f'\n{header_line}\n'
|
||||
is_plus_lines = is_minus_lines = False
|
||||
if new_content_lines:
|
||||
is_plus_lines = any([line.startswith('+') for line in new_content_lines])
|
||||
if old_content_lines:
|
||||
is_minus_lines = any([line.startswith('-') for line in old_content_lines])
|
||||
if is_plus_lines or is_minus_lines: # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
|
||||
patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__new hunk__\n'
|
||||
for i, line_new in enumerate(new_content_lines):
|
||||
patch_with_lines_str += f"{start2 + i} {line_new}\n"
|
||||
if is_minus_lines:
|
||||
patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__old hunk__\n'
|
||||
for line_old in old_content_lines:
|
||||
patch_with_lines_str += f"{line_old}\n"
|
||||
|
||||
return patch_with_lines_str.rstrip()
|
||||
|
||||
|
||||
def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, side) -> tuple[str, str]:
|
||||
try:
|
||||
patch_with_lines_str = f"\n\n## File: '{file_name.strip()}'\n\n"
|
||||
selected_lines = ""
|
||||
patch_lines = patch.splitlines()
|
||||
RE_HUNK_HEADER = re.compile(
|
||||
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
|
||||
match = None
|
||||
start1, size1, start2, size2 = -1, -1, -1, -1
|
||||
skip_hunk = False
|
||||
selected_lines_num = 0
|
||||
for line in patch_lines:
|
||||
if 'no newline at end of file' in line.lower():
|
||||
continue
|
||||
|
||||
if line.startswith('@@'):
|
||||
skip_hunk = False
|
||||
selected_lines_num = 0
|
||||
header_line = line
|
||||
|
||||
match = RE_HUNK_HEADER.match(line)
|
||||
|
||||
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
|
||||
|
||||
# check if line range is in this hunk
|
||||
if side.lower() == 'left':
|
||||
# check if line range is in this hunk
|
||||
if not (start1 <= line_start <= start1 + size1):
|
||||
skip_hunk = True
|
||||
continue
|
||||
elif side.lower() == 'right':
|
||||
if not (start2 <= line_start <= start2 + size2):
|
||||
skip_hunk = True
|
||||
continue
|
||||
patch_with_lines_str += f'\n{header_line}\n'
|
||||
|
||||
elif not skip_hunk:
|
||||
if side.lower() == 'right' and line_start <= start2 + selected_lines_num <= line_end:
|
||||
selected_lines += line + '\n'
|
||||
if side.lower() == 'left' and start1 <= selected_lines_num + start1 <= line_end:
|
||||
selected_lines += line + '\n'
|
||||
patch_with_lines_str += line + '\n'
|
||||
if not line.startswith('-'): # currently we don't support /ask line for deleted lines
|
||||
selected_lines_num += 1
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to extract hunk lines from patch: {e}", artifact={"traceback": traceback.format_exc()})
|
||||
return "", ""
|
||||
|
||||
return patch_with_lines_str.rstrip(), selected_lines.rstrip()
|
||||
70
apps/utils/pr_agent/algo/language_handler.py
Normal file
@ -0,0 +1,70 @@
# Language Selection, source: https://github.com/bigcode-project/bigcode-dataset/blob/main/language_selection/programming-languages-to-file-extensions.json # noqa E501
from typing import Dict

from utils.pr_agent.config_loader import get_settings


def filter_bad_extensions(files):
    # Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501
    bad_extensions = get_settings().bad_extensions.default
    if get_settings().config.use_extra_bad_extensions:
        bad_extensions += get_settings().bad_extensions.extra
    return [f for f in files if f.filename is not None and is_valid_file(f.filename, bad_extensions)]


def is_valid_file(filename: str, bad_extensions=None) -> bool:
    if not filename:
        return False
    if not bad_extensions:
        bad_extensions = get_settings().bad_extensions.default
        if get_settings().config.use_extra_bad_extensions:
            bad_extensions += get_settings().bad_extensions.extra
    return filename.split('.')[-1] not in bad_extensions


def sort_files_by_main_languages(languages: Dict, files: list):
    """
    Sort files by their main language, put the files that are in the main language first and the rest files after
    """
    # sort languages by their size
    languages_sorted_list = [k for k, v in sorted(languages.items(), key=lambda item: item[1], reverse=True)]
    # languages_sorted = sorted(languages, key=lambda x: x[1], reverse=True)
    # get all extensions for the languages
    main_extensions = []
    language_extension_map_org = get_settings().language_extension_map_org
    language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
    for language in languages_sorted_list:
        if language.lower() in language_extension_map:
            main_extensions.append(language_extension_map[language.lower()])
        else:
            main_extensions.append([])

    # filter out files bad extensions
    files_filtered = filter_bad_extensions(files)
    # sort files by their extension, put the files that are in the main extension first
    # and the rest files after, map languages_sorted to their respective files
    files_sorted = []
    rest_files = {}

    # if no languages detected, put all files in the "Other" category
    if not languages:
        files_sorted = [({"language": "Other", "files": list(files_filtered)})]
        return files_sorted

    main_extensions_flat = []
    for ext in main_extensions:
        main_extensions_flat.extend(ext)

    for extensions, lang in zip(main_extensions, languages_sorted_list):  # noqa: B905
        tmp = []
        for file in files_filtered:
            extension_str = f".{file.filename.split('.')[-1]}"
            if extension_str in extensions:
                tmp.append(file)
            else:
                if (file.filename not in rest_files) and (extension_str not in main_extensions_flat):
                    rest_files[file.filename] = file
        if len(tmp) > 0:
            files_sorted.append({"language": lang, "files": tmp})
    files_sorted.append({"language": "Other", "files": list(rest_files.values())})
    return files_sorted
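A hedged usage sketch (not part of the commit), assuming the bundled settings TOML files provide bad_extensions and language_extension_map_org: given the language sizes reported by the git provider and a list of FilePatchInfo objects, the sorter groups files by the PR's dominant languages. The inputs below are invented.

from utils.pr_agent.algo.language_handler import sort_files_by_main_languages
from utils.pr_agent.algo.types import FilePatchInfo

languages = {"Python": 9000, "Markdown": 500}
files = [
    FilePatchInfo(base_file="", head_file="", patch="", filename="apps/pr/models.py"),
    FilePatchInfo(base_file="", head_file="", patch="", filename="README.md"),
]
groups = sort_files_by_main_languages(languages, files)
# expected shape: a list of {'language': ..., 'files': [...]} groups, ending with an 'Other' bucket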
550
apps/utils/pr_agent/algo/pr_processing.py
Normal file
@ -0,0 +1,550 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import traceback
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
from github import RateLimitExceededException
|
||||
|
||||
from utils.pr_agent.algo.file_filter import filter_ignored
|
||||
from utils.pr_agent.algo.git_patch_processing import (
|
||||
convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions)
|
||||
from utils.pr_agent.algo.language_handler import sort_files_by_main_languages
|
||||
from utils.pr_agent.algo.token_handler import TokenHandler
|
||||
from utils.pr_agent.algo.types import EDIT_TYPE
|
||||
from utils.pr_agent.algo.utils import ModelType, clip_tokens, get_max_tokens, get_weak_model
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
from utils.pr_agent.git_providers.git_provider import GitProvider
|
||||
from utils.pr_agent.log import get_logger
|
||||
|
||||
DELETED_FILES_ = "Deleted files:\n"
|
||||
|
||||
MORE_MODIFIED_FILES_ = "Additional modified files (insufficient token budget to process):\n"
|
||||
|
||||
ADDED_FILES_ = "Additional added files (insufficient token budget to process):\n"
|
||||
|
||||
OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1500
|
||||
OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 1000
|
||||
MAX_EXTRA_LINES = 10
|
||||
|
||||
|
||||
def cap_and_log_extra_lines(value, direction) -> int:
|
||||
if value > MAX_EXTRA_LINES:
|
||||
get_logger().warning(f"patch_extra_lines_{direction} was {value}, capping to {MAX_EXTRA_LINES}")
|
||||
return MAX_EXTRA_LINES
|
||||
return value
|
||||
|
||||
|
||||
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
|
||||
model: str,
|
||||
add_line_numbers_to_hunks: bool = False,
|
||||
disable_extra_lines: bool = False,
|
||||
large_pr_handling=False,
|
||||
return_remaining_files=False):
|
||||
if disable_extra_lines:
|
||||
PATCH_EXTRA_LINES_BEFORE = 0
|
||||
PATCH_EXTRA_LINES_AFTER = 0
|
||||
else:
|
||||
PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
|
||||
PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
|
||||
PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(PATCH_EXTRA_LINES_BEFORE, "before")
|
||||
PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(PATCH_EXTRA_LINES_AFTER, "after")
|
||||
|
||||
try:
|
||||
diff_files_original = git_provider.get_diff_files()
|
||||
except RateLimitExceededException as e:
|
||||
get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
|
||||
raise
|
||||
|
||||
diff_files = filter_ignored(diff_files_original)
|
||||
if diff_files != diff_files_original:
|
||||
try:
|
||||
get_logger().info(f"Filtered out {len(diff_files_original) - len(diff_files)} files")
|
||||
new_names = set([a.filename for a in diff_files])
|
||||
orig_names = set([a.filename for a in diff_files_original])
|
||||
get_logger().info(f"Filtered out files: {orig_names - new_names}")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
# get pr languages
|
||||
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
|
||||
if pr_languages:
|
||||
try:
|
||||
get_logger().info(f"PR main language: {pr_languages[0]['language']}")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# generate a standard diff string, with patch extension
|
||||
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
|
||||
pr_languages, token_handler, add_line_numbers_to_hunks,
|
||||
patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE, patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)
|
||||
|
||||
# if we are under the limit, return the full diff
|
||||
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
|
||||
get_logger().info(f"Tokens: {total_tokens}, total tokens under limit: {get_max_tokens(model)}, "
|
||||
f"returning full diff.")
|
||||
return "\n".join(patches_extended)
|
||||
|
||||
# if we are over the limit, start pruning (If we got here, we will not extend the patches with extra lines)
|
||||
get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
|
||||
f"pruning diff.")
|
||||
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
|
||||
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling)
|
||||
|
||||
if large_pr_handling and len(patches_compressed_list) > 1:
|
||||
get_logger().info(f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff.")
|
||||
return "" # return empty string, as we want to generate multiple patches with a different prompt
|
||||
|
||||
# return the first patch
|
||||
patches_compressed = patches_compressed_list[0]
|
||||
total_tokens_new = total_tokens_list[0]
|
||||
files_in_patch = files_in_patches_list[0]
|
||||
|
||||
# Insert additional information about added, modified, and deleted files if there is enough space
|
||||
max_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD
|
||||
curr_token = total_tokens_new # == token_handler.count_tokens(final_diff)+token_handler.prompt_tokens
|
||||
final_diff = "\n".join(patches_compressed)
|
||||
delta_tokens = 10
|
||||
added_list_str = modified_list_str = deleted_list_str = ""
|
||||
unprocessed_files = []
|
||||
# generate the added, modified, and deleted files lists
|
||||
if (max_tokens - curr_token) > delta_tokens:
|
||||
for filename, file_values in file_dict.items():
|
||||
if filename in files_in_patch:
|
||||
continue
|
||||
if file_values['edit_type'] == EDIT_TYPE.ADDED:
|
||||
unprocessed_files.append(filename)
|
||||
if not added_list_str:
|
||||
added_list_str = ADDED_FILES_ + f"\n{filename}"
|
||||
else:
|
||||
added_list_str = added_list_str + f"\n{filename}"
|
||||
elif file_values['edit_type'] in [EDIT_TYPE.MODIFIED, EDIT_TYPE.RENAMED]:
|
||||
unprocessed_files.append(filename)
|
||||
if not modified_list_str:
|
||||
modified_list_str = MORE_MODIFIED_FILES_ + f"\n{filename}"
|
||||
else:
|
||||
modified_list_str = modified_list_str + f"\n{filename}"
|
||||
elif file_values['edit_type'] == EDIT_TYPE.DELETED:
|
||||
# unprocessed_files.append(filename) # not needed here, because the file was deleted, so no need to process it
|
||||
if not deleted_list_str:
|
||||
deleted_list_str = DELETED_FILES_ + f"\n{filename}"
|
||||
else:
|
||||
deleted_list_str = deleted_list_str + f"\n{filename}"
|
||||
|
||||
# prune the added, modified, and deleted files lists, and add them to the final diff
|
||||
added_list_str = clip_tokens(added_list_str, max_tokens - curr_token)
|
||||
if added_list_str:
|
||||
final_diff = final_diff + "\n\n" + added_list_str
|
||||
curr_token += token_handler.count_tokens(added_list_str) + 2
|
||||
modified_list_str = clip_tokens(modified_list_str, max_tokens - curr_token)
|
||||
if modified_list_str:
|
||||
final_diff = final_diff + "\n\n" + modified_list_str
|
||||
curr_token += token_handler.count_tokens(modified_list_str) + 2
|
||||
deleted_list_str = clip_tokens(deleted_list_str, max_tokens - curr_token)
|
||||
if deleted_list_str:
|
||||
final_diff = final_diff + "\n\n" + deleted_list_str
|
||||
|
||||
get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
|
||||
f"deleted_list_str: {deleted_list_str}")
|
||||
if not return_remaining_files:
|
||||
return final_diff
|
||||
else:
|
||||
return final_diff, remaining_files_list
|
||||
|
||||
|
||||
def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler, model: str,
|
||||
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False):
|
||||
try:
|
||||
diff_files_original = git_provider.get_diff_files()
|
||||
except RateLimitExceededException as e:
|
||||
get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
|
||||
raise
|
||||
|
||||
diff_files = filter_ignored(diff_files_original)
|
||||
if diff_files != diff_files_original:
|
||||
try:
|
||||
get_logger().info(f"Filtered out {len(diff_files_original) - len(diff_files)} files")
|
||||
new_names = set([a.filename for a in diff_files])
|
||||
orig_names = set([a.filename for a in diff_files_original])
|
||||
get_logger().info(f"Filtered out files: {orig_names - new_names}")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# get pr languages
|
||||
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
|
||||
if pr_languages:
|
||||
try:
|
||||
get_logger().info(f"PR main language: {pr_languages[0]['language']}")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
|
||||
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling=True)
|
||||
|
||||
return patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
|
||||
|
||||
|
||||
def pr_generate_extended_diff(pr_languages: list,
|
||||
token_handler: TokenHandler,
|
||||
add_line_numbers_to_hunks: bool,
|
||||
patch_extra_lines_before: int = 0,
|
||||
patch_extra_lines_after: int = 0) -> Tuple[list, int, list]:
|
||||
total_tokens = token_handler.prompt_tokens # initial tokens
|
||||
patches_extended = []
|
||||
patches_extended_tokens = []
|
||||
for lang in pr_languages:
|
||||
for file in lang['files']:
|
||||
original_file_content_str = file.base_file
|
||||
patch = file.patch
|
||||
if not patch:
|
||||
continue
|
||||
|
||||
# extend each patch with extra lines of context
|
||||
extended_patch = extend_patch(original_file_content_str, patch,
|
||||
patch_extra_lines_before, patch_extra_lines_after, file.filename)
|
||||
if not extended_patch:
|
||||
get_logger().warning(f"Failed to extend patch for file: {file.filename}")
|
||||
continue
|
||||
|
||||
if add_line_numbers_to_hunks:
|
||||
full_extended_patch = convert_to_hunks_with_lines_numbers(extended_patch, file)
|
||||
else:
|
||||
full_extended_patch = f"\n\n## File: '{file.filename.strip()}'\n{extended_patch.rstrip()}\n"
|
||||
|
||||
# add AI-summary metadata to the patch
|
||||
if file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False):
|
||||
full_extended_patch = add_ai_summary_top_patch(file, full_extended_patch)
|
||||
|
||||
patch_tokens = token_handler.count_tokens(full_extended_patch)
|
||||
file.tokens = patch_tokens
|
||||
total_tokens += patch_tokens
|
||||
patches_extended_tokens.append(patch_tokens)
|
||||
patches_extended.append(full_extended_patch)
|
||||
|
||||
return patches_extended, total_tokens, patches_extended_tokens
|
||||
|
||||
|
||||
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
|
||||
convert_hunks_to_line_numbers: bool,
|
||||
large_pr_handling: bool) -> Tuple[list, list, list, list, dict, list]:
|
||||
deleted_files_list = []
|
||||
|
||||
# sort each one of the languages in top_langs by the number of tokens in the diff
|
||||
sorted_files = []
|
||||
for lang in top_langs:
|
||||
sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))
|
||||
|
||||
# generate patches for each file, and count tokens
|
||||
file_dict = {}
|
||||
for file in sorted_files:
|
||||
original_file_content_str = file.base_file
|
||||
new_file_content_str = file.head_file
|
||||
patch = file.patch
|
||||
if not patch:
|
||||
continue
|
||||
|
||||
# removing delete-only hunks
|
||||
patch = handle_patch_deletions(patch, original_file_content_str,
|
||||
new_file_content_str, file.filename, file.edit_type)
|
||||
if patch is None:
|
||||
if file.filename not in deleted_files_list:
|
||||
deleted_files_list.append(file.filename)
|
||||
continue
|
||||
|
||||
if convert_hunks_to_line_numbers:
|
||||
patch = convert_to_hunks_with_lines_numbers(patch, file)
|
||||
|
||||
## add AI-summary metadata to the patch (disabled, since we are in the compressed diff)
|
||||
# if file.ai_file_summary and get_settings().config.get('config.is_auto_command', False):
|
||||
# patch = add_ai_summary_top_patch(file, patch)
|
||||
|
||||
new_patch_tokens = token_handler.count_tokens(patch)
|
||||
file_dict[file.filename] = {'patch': patch, 'tokens': new_patch_tokens, 'edit_type': file.edit_type}
|
||||
|
||||
max_tokens_model = get_max_tokens(model)
|
||||
|
||||
# first iteration
|
||||
files_in_patches_list = []
|
||||
remaining_files_list = [file.filename for file in sorted_files]
|
||||
patches_list =[]
|
||||
total_tokens_list = []
|
||||
total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers, file_dict,
|
||||
max_tokens_model, remaining_files_list, token_handler)
|
||||
patches_list.append(patches)
|
||||
total_tokens_list.append(total_tokens)
|
||||
files_in_patches_list.append(files_in_patch_list)
|
||||
|
||||
# additional iterations (if needed)
|
||||
if large_pr_handling:
|
||||
NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1 # one more call is to summarize
|
||||
for i in range(NUMBER_OF_ALLOWED_ITERATIONS-1):
|
||||
if remaining_files_list:
|
||||
total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers,
|
||||
file_dict,
|
||||
max_tokens_model,
|
||||
remaining_files_list, token_handler)
|
||||
if patches:
|
||||
patches_list.append(patches)
|
||||
total_tokens_list.append(total_tokens)
|
||||
files_in_patches_list.append(files_in_patch_list)
|
||||
else:
|
||||
break
|
||||
|
||||
return patches_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
|
||||
|
||||
|
||||
def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_model,remaining_files_list_prev, token_handler):
|
||||
total_tokens = token_handler.prompt_tokens # initial tokens
|
||||
patches = []
|
||||
remaining_files_list_new = []
|
||||
files_in_patch_list = []
|
||||
for filename, data in file_dict.items():
|
||||
if filename not in remaining_files_list_prev:
|
||||
continue
|
||||
|
||||
patch = data['patch']
|
||||
new_patch_tokens = data['tokens']
|
||||
edit_type = data['edit_type']
|
||||
|
||||
# Hard Stop, no more tokens
|
||||
if total_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
|
||||
get_logger().warning(f"File was fully skipped, no more tokens: {filename}.")
|
||||
continue
|
||||
|
||||
# If the patch is too large, just show the file name
|
||||
if total_tokens + new_patch_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
||||
# Current logic is to skip the patch if it's too large
|
||||
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
|
||||
# until we meet the requirements
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().warning(f"Patch too large, skipping it: '{filename}'")
|
||||
remaining_files_list_new.append(filename)
|
||||
continue
|
||||
|
||||
if patch:
|
||||
if not convert_hunks_to_line_numbers:
|
||||
patch_final = f"\n\n## File: '{filename.strip()}'\n\n{patch.strip()}\n"
|
||||
else:
|
||||
patch_final = "\n\n" + patch.strip()
|
||||
patches.append(patch_final)
|
||||
total_tokens += token_handler.count_tokens(patch_final)
|
||||
files_in_patch_list.append(filename)
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Tokens: {total_tokens}, last filename: {filename}")
|
||||
return total_tokens, patches, remaining_files_list_new, files_in_patch_list
|
||||
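An illustration only (not part of the commit) of the two token-budget thresholds generate_full_patch applies, OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD (1000) and OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD (1500); the 32k context window and token counts are made up.

max_tokens_model = 32_000                 # hypothetical context window
hard_limit = max_tokens_model - 1000      # beyond this, remaining patches are skipped entirely
soft_limit = max_tokens_model - 1500      # a patch that would cross this is deferred to the next call
total_tokens, patch_tokens = 30_200, 900
assert total_tokens <= hard_limit                    # still allowed to add something
assert total_tokens + patch_tokens > soft_limit      # but this patch is too big, so it is deferred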
|
||||
|
||||
async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
|
||||
all_models = _get_all_models(model_type)
|
||||
all_deployments = _get_all_deployments(all_models)
|
||||
# try each (model, deployment_id) pair until one is successful, otherwise raise exception
|
||||
for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
|
||||
try:
|
||||
get_logger().debug(
|
||||
f"Generating prediction with {model}"
|
||||
f"{(' from deployment ' + deployment_id) if deployment_id else ''}"
|
||||
)
|
||||
get_settings().set("openai.deployment_id", deployment_id)
|
||||
return await f(model)
|
||||
except:
|
||||
get_logger().warning(
|
||||
f"Failed to generate prediction with {model}"
|
||||
)
|
||||
if i == len(all_models) - 1: # If it's the last iteration
|
||||
raise Exception(f"Failed to generate prediction with any model of {all_models}")
|
||||
|
||||
|
||||
def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
|
||||
if model_type == ModelType.WEAK:
|
||||
model = get_weak_model()
|
||||
else:
|
||||
model = get_settings().config.model
|
||||
fallback_models = get_settings().config.fallback_models
|
||||
if not isinstance(fallback_models, list):
|
||||
fallback_models = [m.strip() for m in fallback_models.split(",")]
|
||||
all_models = [model] + fallback_models
|
||||
return all_models
|
||||
|
||||
|
||||
def _get_all_deployments(all_models: List[str]) -> List[str]:
|
||||
deployment_id = get_settings().get("openai.deployment_id", None)
|
||||
fallback_deployments = get_settings().get("openai.fallback_deployments", [])
|
||||
if not isinstance(fallback_deployments, list) and fallback_deployments:
|
||||
fallback_deployments = [d.strip() for d in fallback_deployments.split(",")]
|
||||
if fallback_deployments:
|
||||
all_deployments = [deployment_id] + fallback_deployments
|
||||
if len(all_deployments) < len(all_models):
|
||||
raise ValueError(f"The number of deployments ({len(all_deployments)}) "
|
||||
f"is less than the number of models ({len(all_models)})")
|
||||
else:
|
||||
all_deployments = [deployment_id] * len(all_models)
|
||||
return all_deployments
|
||||
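A hedged configuration sketch (not part of the commit): fallback models may be a list or a comma-separated string, and Azure-style deployments, if used, must line up with the models. The model and deployment names below are placeholders.

from utils.pr_agent.config_loader import get_settings

get_settings().set("config.model", "gpt-4o")
get_settings().set("config.fallback_models", "gpt-4o-mini, gpt-3.5-turbo")
get_settings().set("openai.deployment_id", "primary-deployment")
get_settings().set("openai.fallback_deployments", "mini-deployment, turbo-deployment")
# _get_all_models()      -> ['gpt-4o', 'gpt-4o-mini', 'gpt-3.5-turbo']
# _get_all_deployments() -> pairs each model with a deployment, raising if there are fewer deployments than models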
|
||||
|
||||
def get_pr_multi_diffs(git_provider: GitProvider,
|
||||
token_handler: TokenHandler,
|
||||
model: str,
|
||||
max_calls: int = 5) -> List[str]:
|
||||
"""
|
||||
Retrieves the diff files from a Git provider, sorts them by main language, and generates patches for each file.
|
||||
The patches are split into multiple groups based on the maximum number of tokens allowed for the given model.
|
||||
|
||||
Args:
|
||||
git_provider (GitProvider): An object that provides access to Git provider APIs.
|
||||
token_handler (TokenHandler): An object that handles tokens in the context of a pull request.
|
||||
model (str): The name of the model.
|
||||
max_calls (int, optional): The maximum number of calls to retrieve diff files. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of final diff strings, split into multiple groups based on the maximum number of tokens allowed for the given model.
|
||||
|
||||
Raises:
|
||||
RateLimitExceededException: If the rate limit for the Git provider API is exceeded.
|
||||
"""
|
||||
try:
|
||||
diff_files = git_provider.get_diff_files()
|
||||
except RateLimitExceededException as e:
|
||||
get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
|
||||
raise
|
||||
|
||||
diff_files = filter_ignored(diff_files)
|
||||
|
||||
# Sort files by main language
|
||||
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
|
||||
|
||||
# Sort files within each language group by tokens in descending order
|
||||
sorted_files = []
|
||||
for lang in pr_languages:
|
||||
sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))
|
||||
|
||||
# Get the maximum number of extra lines before and after the patch
|
||||
PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
|
||||
PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
|
||||
PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(PATCH_EXTRA_LINES_BEFORE, "before")
|
||||
PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(PATCH_EXTRA_LINES_AFTER, "after")
|
||||
|
||||
# try first a single run with standard diff string, with patch extension, and no deletions
|
||||
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
|
||||
pr_languages, token_handler, add_line_numbers_to_hunks=True,
|
||||
patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE,
|
||||
patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)
|
||||
|
||||
# if we are under the limit, return the full diff
|
||||
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
|
||||
return ["\n".join(patches_extended)] if patches_extended else []
|
||||
|
||||
patches = []
|
||||
final_diff_list = []
|
||||
total_tokens = token_handler.prompt_tokens
|
||||
call_number = 1
|
||||
for file in sorted_files:
|
||||
if call_number > max_calls:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Reached max calls ({max_calls})")
|
||||
break
|
||||
|
||||
original_file_content_str = file.base_file
|
||||
new_file_content_str = file.head_file
|
||||
patch = file.patch
|
||||
if not patch:
|
||||
continue
|
||||
|
||||
# Remove delete-only hunks
|
||||
patch = handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file.filename, file.edit_type)
|
||||
if patch is None:
|
||||
continue
|
||||
|
||||
patch = convert_to_hunks_with_lines_numbers(patch, file)
|
||||
# add AI-summary metadata to the patch
|
||||
if file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False):
|
||||
patch = add_ai_summary_top_patch(file, patch)
|
||||
new_patch_tokens = token_handler.count_tokens(patch)
|
||||
|
||||
if patch and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
|
||||
model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
||||
if get_settings().config.get('large_patch_policy', 'skip') == 'skip':
|
||||
get_logger().warning(f"Patch too large, skipping: {file.filename}")
|
||||
continue
|
||||
elif get_settings().config.get('large_patch_policy') == 'clip':
|
||||
delta_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens
|
||||
patch_clipped = clip_tokens(patch, delta_tokens, delete_last_line=True, num_input_tokens=new_patch_tokens)
|
||||
new_patch_tokens = token_handler.count_tokens(patch_clipped)
|
||||
if patch_clipped and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
|
||||
model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
||||
get_logger().warning(f"Patch too large, skipping: {file.filename}")
|
||||
continue
|
||||
else:
|
||||
get_logger().info(f"Clipped large patch for file: {file.filename}")
|
||||
patch = patch_clipped
|
||||
else:
|
||||
get_logger().warning(f"Patch too large, skipping: {file.filename}")
|
||||
continue
|
||||
|
||||
if patch and (total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
|
||||
final_diff = "\n".join(patches)
|
||||
final_diff_list.append(final_diff)
|
||||
patches = []
|
||||
total_tokens = token_handler.prompt_tokens
|
||||
call_number += 1
|
||||
if call_number > max_calls: # avoid creating new patches
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Reached max calls ({max_calls})")
|
||||
break
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Call number: {call_number}")
|
||||
|
||||
if patch:
|
||||
patches.append(patch)
|
||||
total_tokens += new_patch_tokens
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Tokens: {total_tokens}, last filename: {file.filename}")
|
||||
|
||||
# Add the last chunk
|
||||
if patches:
|
||||
final_diff = "\n".join(patches)
|
||||
final_diff_list.append(final_diff)
|
||||
|
||||
return final_diff_list
|
||||
|
||||
|
||||
def add_ai_metadata_to_diff_files(git_provider, pr_description_files):
|
||||
"""
|
||||
Adds AI metadata to the diff files based on the PR description files (FilePatchInfo.ai_file_summary).
|
||||
"""
|
||||
try:
|
||||
if not pr_description_files:
|
||||
get_logger().warning(f"PR description files are empty.")
|
||||
return
|
||||
available_files = {pr_file['full_file_name'].strip(): pr_file for pr_file in pr_description_files}
|
||||
diff_files = git_provider.get_diff_files()
|
||||
found_any_match = False
|
||||
for file in diff_files:
|
||||
filename = file.filename.strip()
|
||||
if filename in available_files:
|
||||
file.ai_file_summary = available_files[filename]
|
||||
found_any_match = True
|
||||
if not found_any_match:
|
||||
get_logger().error(f"Failed to find any matching files between PR description and diff files.",
|
||||
artifact={"pr_description_files": pr_description_files})
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to add AI metadata to diff files: {e}",
|
||||
artifact={"traceback": traceback.format_exc()})
|
||||
|
||||
|
||||
def add_ai_summary_top_patch(file, full_extended_patch):
|
||||
try:
|
||||
# below every instance of '## File: ...' in the patch, add the ai-summary metadata
|
||||
full_extended_patch_lines = full_extended_patch.split("\n")
|
||||
for i, line in enumerate(full_extended_patch_lines):
|
||||
if line.startswith("## File:") or line.startswith("## file:"):
|
||||
full_extended_patch_lines.insert(i + 1,
|
||||
f"### AI-generated changes summary:\n{file.ai_file_summary['long_summary']}")
|
||||
full_extended_patch = "\n".join(full_extended_patch_lines)
|
||||
return full_extended_patch
|
||||
|
||||
# if no '## File: ...' was found
|
||||
return full_extended_patch
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to add AI summary to the top of the patch: {e}",
|
||||
artifact={"traceback": traceback.format_exc()})
|
||||
return full_extended_patch
|
||||
89
apps/utils/pr_agent/algo/token_handler.py
Normal file
@ -0,0 +1,89 @@
from threading import Lock

from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model, get_encoding

from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.log import get_logger


class TokenEncoder:
    _encoder_instance = None
    _model = None
    _lock = Lock()  # Create a lock object

    @classmethod
    def get_token_encoder(cls):
        model = get_settings().config.model
        if cls._encoder_instance is None or model != cls._model:  # Check without acquiring the lock for performance
            with cls._lock:  # Lock acquisition to ensure thread safety
                if cls._encoder_instance is None or model != cls._model:
                    cls._model = model
                    cls._encoder_instance = encoding_for_model(cls._model) if "gpt" in cls._model else get_encoding(
                        "cl100k_base")
        return cls._encoder_instance


class TokenHandler:
    """
    A class for handling tokens in the context of a pull request.

    Attributes:
    - encoder: An object of the encoding_for_model class from the tiktoken module. Used to encode strings and count the
      number of tokens in them.
    - limit: The maximum number of tokens allowed for the given model, as defined in the MAX_TOKENS dictionary in the
      pr_agent.algo module.
    - prompt_tokens: The number of tokens in the system and user strings, as calculated by the _get_system_user_tokens
      method.
    """

    def __init__(self, pr=None, vars: dict = {}, system="", user=""):
        """
        Initializes the TokenHandler object.

        Args:
        - pr: The pull request object.
        - vars: A dictionary of variables.
        - system: The system string.
        - user: The user string.
        """
        self.encoder = TokenEncoder.get_token_encoder()
        if pr is not None:
            self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)

    def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
        """
        Calculates the number of tokens in the system and user strings.

        Args:
        - pr: The pull request object.
        - encoder: An object of the encoding_for_model class from the tiktoken module.
        - vars: A dictionary of variables.
        - system: The system string.
        - user: The user string.

        Returns:
        The sum of the number of tokens in the system and user strings.
        """
        try:
            environment = Environment(undefined=StrictUndefined)
            system_prompt = environment.from_string(system).render(vars)
            user_prompt = environment.from_string(user).render(vars)
            system_prompt_tokens = len(encoder.encode(system_prompt))
            user_prompt_tokens = len(encoder.encode(user_prompt))
            return system_prompt_tokens + user_prompt_tokens
        except Exception as e:
            get_logger().error(f"Error in _get_system_user_tokens: {e}")
            return 0

    def count_tokens(self, patch: str) -> int:
        """
        Counts the number of tokens in a given patch string.

        Args:
        - patch: The patch string.

        Returns:
        The number of tokens in the patch string.
        """
        return len(self.encoder.encode(patch, disallowed_special=()))
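A hedged usage sketch (not part of the commit), assuming config.model resolves to an encoding tiktoken knows: system and user are Jinja2 templates rendered with vars before counting, and the pr argument is only used as an opt-in flag here. The template strings and variables are invented.

from utils.pr_agent.algo.token_handler import TokenHandler

handler = TokenHandler(pr=object(), vars={"title": "first commit"},
                       system="You are a PR reviewer.", user="Review PR: {{ title }}")
print(handler.prompt_tokens)                     # tokens spent on the rendered prompts
print(handler.count_tokens("+ print('hello')"))  # tokens a single patch line would add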
26
apps/utils/pr_agent/algo/types.py
Normal file
@ -0,0 +1,26 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class EDIT_TYPE(Enum):
    ADDED = 1
    DELETED = 2
    MODIFIED = 3
    RENAMED = 4
    UNKNOWN = 5


@dataclass
class FilePatchInfo:
    base_file: str
    head_file: str
    patch: str
    filename: str
    tokens: int = -1
    edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN
    old_filename: str = None
    num_plus_lines: int = -1
    num_minus_lines: int = -1
    language: Optional[str] = None
    ai_file_summary: str = None
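A minimal construction sketch (not part of the commit): this is the shape the diff-processing helpers above expect from every git provider. The diff content below is made up.

from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo

info = FilePatchInfo(
    base_file="print('old')\n",
    head_file="print('new')\n",
    patch="@@ -1 +1 @@\n-print('old')\n+print('new')\n",
    filename="hello.py",
    edit_type=EDIT_TYPE.MODIFIED,
)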
1247
apps/utils/pr_agent/algo/utils.py
Normal file
File diff suppressed because it is too large
Load Diff
96
apps/utils/pr_agent/cli.py
Normal file
@ -0,0 +1,96 @@
import argparse
import asyncio
import os

from utils.pr_agent.agent.pr_agent import PRAgent, commands
from utils.pr_agent.algo.utils import get_version
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.log import get_logger, setup_logger

log_level = os.environ.get("LOG_LEVEL", "INFO")
setup_logger(log_level)


def set_parser():
    parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage=
"""\
Usage: cli.py --pr_url=<URL on supported git hosting service> <command> [<args>].
For example:
- cli.py --pr_url=... review
- cli.py --pr_url=... describe
- cli.py --pr_url=... improve
- cli.py --pr_url=... ask "write me a poem about this PR"
- cli.py --pr_url=... reflect
- cli.py --issue_url=... similar_issue

Supported commands:
- review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.

- ask / ask_question [question] - Ask a question about the PR.

- describe / describe_pr - Modify the PR title and description based on the PR's contents.

- improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
Extended mode ('improve --extended') employs several calls, and provides a more thorough feedback

- reflect - Ask the PR author questions about the PR.

- update_changelog - Update the changelog based on the PR's contents.

- add_docs

- generate_labels


Configuration:
To edit any configuration parameter from 'configuration.toml', just add -config_path=<value>.
For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
""")
    parser.add_argument('--version', action='version', version=f'pr-agent {get_version()}')
    parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', default=None)
    parser.add_argument('--issue_url', type=str, help='The URL of the Issue to review', default=None)
    parser.add_argument('command', type=str, help='The command to run', choices=commands, default='review')
    parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
    return parser


def run_command(pr_url, command):
    # Preparing the command
    run_command_str = f"--pr_url={pr_url} {command.lstrip('/')}"
    args = set_parser().parse_args(run_command_str.split())

    # Run the command. Feedback will appear in GitHub PR comments
    run(args=args)


def run(inargs=None, args=None):
    parser = set_parser()
    if not args:
        args = parser.parse_args(inargs)
    if not args.pr_url and not args.issue_url:
        parser.print_help()
        return

    command = args.command.lower()
    get_settings().set("CONFIG.CLI_MODE", True)

    async def inner():
        if args.issue_url:
            result = await asyncio.create_task(PRAgent().handle_request(args.issue_url, [command] + args.rest))
        else:
            result = await asyncio.create_task(PRAgent().handle_request(args.pr_url, [command] + args.rest))

        if get_settings().litellm.get("enable_callbacks", False):
            # There may be additional events on the event queue from the run above. If there are, give them time to complete.
            get_logger().debug("Waiting for event queue to complete")
            await asyncio.wait([task for task in asyncio.all_tasks() if task is not asyncio.current_task()])

        return result

    result = asyncio.run(inner())
    if not result:
        parser.print_help()


if __name__ == '__main__':
    run()
23
apps/utils/pr_agent/cli_pip.py
Normal file
@ -0,0 +1,23 @@
from utils.pr_agent import cli
from utils.pr_agent.config_loader import get_settings


def main():
    # Fill in the following values
    provider = "github"  # GitHub provider
    user_token = "..."  # GitHub user token
    openai_key = "..."  # OpenAI key
    pr_url = "..."  # PR URL, for example 'https://github.com/Codium-ai/pr-agent/pull/809'
    command = "/review"  # Command to run (e.g. '/review', '/describe', '/ask="What is the purpose of this PR?"')

    # Setting the configurations
    get_settings().set("CONFIG.git_provider", provider)
    get_settings().set("openai.key", openai_key)
    get_settings().set("github.user_token", user_token)

    # Run the command. Feedback will appear in GitHub PR comments
    cli.run_command(pr_url, command)


if __name__ == '__main__':
    main()
81
apps/utils/pr_agent/config_loader.py
Normal file
81
apps/utils/pr_agent/config_loader.py
Normal file
@ -0,0 +1,81 @@
from os.path import abspath, dirname, join
from pathlib import Path
from typing import Optional

from dynaconf import Dynaconf
from starlette_context import context

PR_AGENT_TOML_KEY = 'pr-agent'

current_dir = dirname(abspath(__file__))
global_settings = Dynaconf(
    envvar_prefix=False,
    merge_enabled=True,
    settings_files=[join(current_dir, f) for f in [
        "settings/configuration.toml",
        "settings/ignore.toml",
        "settings/language_extensions.toml",
        "settings/pr_reviewer_prompts.toml",
        "settings/pr_questions_prompts.toml",
        "settings/pr_line_questions_prompts.toml",
        "settings/pr_description_prompts.toml",
        "settings/pr_code_suggestions_prompts.toml",
        "settings/pr_code_suggestions_reflect_prompts.toml",
        "settings/pr_sort_code_suggestions_prompts.toml",
        "settings/pr_information_from_user_prompts.toml",
        "settings/pr_update_changelog_prompts.toml",
        "settings/pr_custom_labels.toml",
        "settings/pr_add_docs.toml",
        "settings/custom_labels.toml",
        "settings/pr_help_prompts.toml",
        "settings/.secrets.toml",
        "settings_prod/.secrets.toml",
    ]]
)


def get_settings():
    """
    Retrieves the current settings.

    This function attempts to fetch the settings from the starlette_context's context object. If it fails,
    it defaults to the global settings defined outside of this function.

    Returns:
        Dynaconf: The current settings object, either from the context or the global default.
    """
    try:
        return context["settings"]
    except Exception:
        return global_settings


# Add local configuration from pyproject.toml of the project being reviewed
def _find_repository_root() -> Optional[Path]:
    """
    Identify project root directory by recursively searching for the .git directory in the parent directories.
    """
    cwd = Path.cwd().resolve()
    no_way_up = False
    while not no_way_up:
        no_way_up = cwd == cwd.parent
        if (cwd / ".git").is_dir():
            return cwd
        cwd = cwd.parent
    return None


def _find_pyproject() -> Optional[Path]:
    """
    Search for file pyproject.toml in the repository root.
    """
    repo_root = _find_repository_root()
    if repo_root:
        pyproject = repo_root / "pyproject.toml"
        return pyproject if pyproject.is_file() else None
    return None


pyproject_path = _find_pyproject()
if pyproject_path is not None:
    get_settings().load_file(pyproject_path, env=f'tool.{PR_AGENT_TOML_KEY}')
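Outside a FastAPI request there is no starlette_context, so get_settings() falls back to the module-level global_settings. The minimal sketch below shows reading and overriding values at runtime, using only the Dynaconf accessors already exercised elsewhere in this repo (.set, .get, and attribute access); the specific key names are illustrative.

```python
from utils.pr_agent.config_loader import get_settings

settings = get_settings()  # resolves to global_settings when no request context is active

# Attribute access, as used in git_providers/__init__.py (raises AttributeError if missing).
provider = settings.config.git_provider

# Dict-style access with a default, as used by the providers.
token = settings.get("BITBUCKET.BEARER_TOKEN", None)

# Override a value for the current process only.
settings.set("CONFIG.git_provider", "gitlab")
```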
64
apps/utils/pr_agent/git_providers/__init__.py
Normal file
64
apps/utils/pr_agent/git_providers/__init__.py
Normal file
@ -0,0 +1,64 @@
from starlette_context import context

from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider
from utils.pr_agent.git_providers.bitbucket_provider import BitbucketProvider
from utils.pr_agent.git_providers.bitbucket_server_provider import \
    BitbucketServerProvider
from utils.pr_agent.git_providers.codecommit_provider import CodeCommitProvider
from utils.pr_agent.git_providers.gerrit_provider import GerritProvider
from utils.pr_agent.git_providers.git_provider import GitProvider
from utils.pr_agent.git_providers.github_provider import GithubProvider
from utils.pr_agent.git_providers.gitlab_provider import GitLabProvider
from utils.pr_agent.git_providers.local_git_provider import LocalGitProvider

_GIT_PROVIDERS = {
    'github': GithubProvider,
    'gitlab': GitLabProvider,
    'bitbucket': BitbucketProvider,
    'bitbucket_server': BitbucketServerProvider,
    'azure': AzureDevopsProvider,
    'codecommit': CodeCommitProvider,
    'local': LocalGitProvider,
    'gerrit': GerritProvider,
}


def get_git_provider():
    try:
        provider_id = get_settings().config.git_provider
    except AttributeError as e:
        raise ValueError("git_provider is a required attribute in the configuration file") from e
    if provider_id not in _GIT_PROVIDERS:
        raise ValueError(f"Unknown git provider: {provider_id}")
    return _GIT_PROVIDERS[provider_id]


def get_git_provider_with_context(pr_url) -> GitProvider:
    """
    Get a GitProvider instance for the given PR URL. If the GitProvider instance is already in the context, return it.
    """

    is_context_env = None
    try:
        is_context_env = context.get("settings", None)
    except Exception:
        pass  # we are not in a context environment (CLI)

    # check if context["git_provider"]["pr_url"] exists
    if is_context_env and context.get("git_provider", {}).get("pr_url", {}):
        git_provider = context["git_provider"]["pr_url"]
        # possibly check if the git_provider is still valid, or if some reset is needed
        # ...
        return git_provider
    else:
        try:
            provider_id = get_settings().config.git_provider
            if provider_id not in _GIT_PROVIDERS:
                raise ValueError(f"Unknown git provider: {provider_id}")
            git_provider = _GIT_PROVIDERS[provider_id](pr_url)
            if is_context_env:
                context["git_provider"] = {pr_url: git_provider}
            return git_provider
        except Exception as e:
            raise ValueError(f"Failed to get git provider for {pr_url}") from e
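A short sketch of how the registry above is typically driven: the provider name is read from configuration, resolved against _GIT_PROVIDERS, and instantiated for a PR URL. The PR URL below is the same illustrative one used in cli_pip.py, not a required value.

```python
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import get_git_provider, get_git_provider_with_context

get_settings().set("CONFIG.git_provider", "github")

# Resolve the provider class only.
provider_cls = get_git_provider()

# Or resolve and instantiate in one step; outside a request context a fresh instance is created each call.
provider = get_git_provider_with_context("https://github.com/Codium-ai/pr-agent/pull/809")
print(provider.get_title())
```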
620
apps/utils/pr_agent/git_providers/azuredevops_provider.py
Normal file
620
apps/utils/pr_agent/git_providers/azuredevops_provider.py
Normal file
@ -0,0 +1,620 @@
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
|
||||
from ..algo.file_filter import filter_ignored
|
||||
from ..algo.language_handler import is_valid_file
|
||||
from ..algo.utils import (PRDescriptionHeader, find_line_number_of_relevant_line_in_file,
|
||||
load_large_diff)
|
||||
from ..config_loader import get_settings
|
||||
from ..log import get_logger
|
||||
from .git_provider import GitProvider
|
||||
|
||||
AZURE_DEVOPS_AVAILABLE = True
|
||||
ADO_APP_CLIENT_DEFAULT_ID = "499b84ac-1321-427f-aa17-267ca6975798/.default"
|
||||
MAX_PR_DESCRIPTION_AZURE_LENGTH = 4000-1
|
||||
|
||||
try:
|
||||
# noinspection PyUnresolvedReferences
from azure.devops.connection import Connection
|
||||
# noinspection PyUnresolvedReferences
|
||||
from azure.devops.v7_1.git.models import (Comment, CommentThread,
|
||||
GitPullRequest,
|
||||
GitPullRequestIterationChanges,
|
||||
GitVersionDescriptor)
|
||||
# noinspection PyUnresolvedReferences
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from msrest.authentication import BasicAuthentication
|
||||
except ImportError:
|
||||
AZURE_DEVOPS_AVAILABLE = False
|
||||
|
||||
|
||||
class AzureDevopsProvider(GitProvider):
|
||||
|
||||
def __init__(
|
||||
self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
|
||||
):
|
||||
if not AZURE_DEVOPS_AVAILABLE:
|
||||
raise ImportError(
|
||||
"Azure DevOps provider is not available. Please install the required dependencies."
|
||||
)
|
||||
|
||||
self.azure_devops_client = self._get_azure_devops_client()
|
||||
self.diff_files = None
|
||||
self.workspace_slug = None
|
||||
self.repo_slug = None
|
||||
self.repo = None
|
||||
self.pr_num = None
|
||||
self.pr = None
|
||||
self.temp_comments = []
|
||||
self.incremental = incremental
|
||||
if pr_url:
|
||||
self.set_pr(pr_url)
|
||||
|
||||
def publish_code_suggestions(self, code_suggestions: list) -> bool:
|
||||
"""
|
||||
Publishes code suggestions as comments on the PR.
|
||||
"""
|
||||
post_parameters_list = []
|
||||
for suggestion in code_suggestions:
|
||||
body = suggestion['body']
|
||||
relevant_file = suggestion['relevant_file']
|
||||
relevant_lines_start = suggestion['relevant_lines_start']
|
||||
relevant_lines_end = suggestion['relevant_lines_end']
|
||||
|
||||
if not relevant_lines_start or relevant_lines_start == -1:
|
||||
get_logger().warning(
|
||||
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
|
||||
continue
|
||||
|
||||
if relevant_lines_end < relevant_lines_start:
|
||||
get_logger().warning(f"Failed to publish code suggestion, "
|
||||
f"relevant_lines_end is {relevant_lines_end} and "
|
||||
f"relevant_lines_start is {relevant_lines_start}")
|
||||
continue
|
||||
|
||||
if relevant_lines_end > relevant_lines_start:
|
||||
post_parameters = {
|
||||
"body": body,
|
||||
"path": relevant_file,
|
||||
"line": relevant_lines_end,
|
||||
"start_line": relevant_lines_start,
|
||||
"start_side": "RIGHT",
|
||||
}
|
||||
else: # API is different for single line comments
|
||||
post_parameters = {
|
||||
"body": body,
|
||||
"path": relevant_file,
|
||||
"line": relevant_lines_start,
|
||||
"side": "RIGHT",
|
||||
}
|
||||
post_parameters_list.append(post_parameters)
|
||||
if not post_parameters_list:
|
||||
return False
|
||||
|
||||
for post_parameters in post_parameters_list:
|
||||
try:
|
||||
comment = Comment(content=post_parameters["body"], comment_type=1)
|
||||
thread = CommentThread(comments=[comment],
|
||||
thread_context={
|
||||
"filePath": post_parameters["path"],
|
||||
"rightFileStart": {
|
||||
"line": post_parameters["start_line"],
|
||||
"offset": 1,
|
||||
},
|
||||
"rightFileEnd": {
|
||||
"line": post_parameters["line"],
|
||||
"offset": 1,
|
||||
},
|
||||
})
|
||||
self.azure_devops_client.create_thread(
|
||||
comment_thread=thread,
|
||||
project=self.workspace_slug,
|
||||
repository_id=self.repo_slug,
|
||||
pull_request_id=self.pr_num
|
||||
)
|
||||
except Exception as e:
|
||||
get_logger().warning(f"Azure failed to publish code suggestion, error: {e}")
|
||||
return True
|
||||
|
||||
|
||||
|
||||
def get_pr_description_full(self) -> str:
|
||||
return self.pr.description
|
||||
|
||||
def edit_comment(self, comment, body: str):
|
||||
try:
|
||||
self.azure_devops_client.update_comment(
|
||||
repository_id=self.repo_slug,
|
||||
pull_request_id=self.pr_num,
|
||||
thread_id=comment["thread_id"],
|
||||
comment_id=comment["comment_id"],
|
||||
comment=Comment(content=body),
|
||||
project=self.workspace_slug,
|
||||
)
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to edit comment, error: {e}")
|
||||
|
||||
def remove_comment(self, comment):
|
||||
try:
|
||||
self.azure_devops_client.delete_comment(
|
||||
repository_id=self.repo_slug,
|
||||
pull_request_id=self.pr_num,
|
||||
thread_id=comment["thread_id"],
|
||||
comment_id=comment["comment_id"],
|
||||
project=self.workspace_slug,
|
||||
)
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to remove comment, error: {e}")
|
||||
|
||||
def publish_labels(self, pr_types):
|
||||
try:
|
||||
for pr_type in pr_types:
|
||||
self.azure_devops_client.create_pull_request_label(
|
||||
label={"name": pr_type},
|
||||
project=self.workspace_slug,
|
||||
repository_id=self.repo_slug,
|
||||
pull_request_id=self.pr_num,
|
||||
)
|
||||
except Exception as e:
|
||||
get_logger().warning(f"Failed to publish labels, error: {e}")
|
||||
|
||||
def get_pr_labels(self, update=False):
|
||||
try:
|
||||
labels = self.azure_devops_client.get_pull_request_labels(
|
||||
project=self.workspace_slug,
|
||||
repository_id=self.repo_slug,
|
||||
pull_request_id=self.pr_num,
|
||||
)
|
||||
return [label.name for label in labels]
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to get labels, error: {e}")
|
||||
return []
|
||||
|
||||
def is_supported(self, capability: str) -> bool:
|
||||
if capability in [
|
||||
"get_issue_comments",
|
||||
]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def set_pr(self, pr_url: str):
|
||||
self.workspace_slug, self.repo_slug, self.pr_num = self._parse_pr_url(pr_url)
|
||||
self.pr = self._get_pr()
|
||||
|
||||
def get_repo_settings(self):
|
||||
try:
|
||||
contents = self.azure_devops_client.get_item_content(
|
||||
repository_id=self.repo_slug,
|
||||
project=self.workspace_slug,
|
||||
download=False,
|
||||
include_content_metadata=False,
|
||||
include_content=True,
|
||||
path=".pr_agent.toml",
|
||||
)
|
||||
return list(contents)[0]
|
||||
except Exception as e:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().error(f"Failed to get repo settings, error: {e}")
|
||||
return ""
|
||||
|
||||
def get_files(self):
|
||||
files = []
|
||||
for i in self.azure_devops_client.get_pull_request_commits(
|
||||
project=self.workspace_slug,
|
||||
repository_id=self.repo_slug,
|
||||
pull_request_id=self.pr_num,
|
||||
):
|
||||
changes_obj = self.azure_devops_client.get_changes(
|
||||
project=self.workspace_slug,
|
||||
repository_id=self.repo_slug,
|
||||
commit_id=i.commit_id,
|
||||
)
|
||||
|
||||
for c in changes_obj.changes:
|
||||
files.append(c["item"]["path"])
|
||||
return list(set(files))
|
||||
|
||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||
try:
|
||||
|
||||
if self.diff_files:
|
||||
return self.diff_files
|
||||
|
||||
base_sha = self.pr.last_merge_target_commit
|
||||
head_sha = self.pr.last_merge_source_commit
|
||||
|
||||
# Get PR iterations
|
||||
iterations = self.azure_devops_client.get_pull_request_iterations(
|
||||
repository_id=self.repo_slug,
|
||||
pull_request_id=self.pr_num,
|
||||
project=self.workspace_slug
|
||||
)
|
||||
changes = None
|
||||
if iterations:
|
||||
iteration_id = iterations[-1].id # Get the last iteration (most recent changes)
|
||||
|
||||
# Get changes for the iteration
|
||||
changes = self.azure_devops_client.get_pull_request_iteration_changes(
|
||||
repository_id=self.repo_slug,
|
||||
pull_request_id=self.pr_num,
|
||||
iteration_id=iteration_id,
|
||||
project=self.workspace_slug
|
||||
)
|
||||
diff_files = []
|
||||
diffs = []
|
||||
diff_types = {}
|
||||
if changes:
|
||||
for change in changes.change_entries:
|
||||
item = change.additional_properties.get('item', {})
|
||||
path = item.get('path', None)
|
||||
if path:
|
||||
diffs.append(path)
|
||||
diff_types[path] = change.additional_properties.get('changeType', 'Unknown')
|
||||
|
||||
# wrong implementation - gets all the files that were changed in any commit in the PR
|
||||
# commits = self.azure_devops_client.get_pull_request_commits(
|
||||
# project=self.workspace_slug,
|
||||
# repository_id=self.repo_slug,
|
||||
# pull_request_id=self.pr_num,
|
||||
# )
|
||||
#
|
||||
# diff_files = []
|
||||
# diffs = []
|
||||
# diff_types = {}
|
||||
|
||||
# for c in commits:
|
||||
# changes_obj = self.azure_devops_client.get_changes(
|
||||
# project=self.workspace_slug,
|
||||
# repository_id=self.repo_slug,
|
||||
# commit_id=c.commit_id,
|
||||
# )
|
||||
# for i in changes_obj.changes:
|
||||
# if i["item"]["gitObjectType"] == "tree":
|
||||
# continue
|
||||
# diffs.append(i["item"]["path"])
|
||||
# diff_types[i["item"]["path"]] = i["changeType"]
|
||||
#
|
||||
# diffs = list(set(diffs))
|
||||
|
||||
diffs_original = diffs
|
||||
diffs = filter_ignored(diffs_original, 'azure')
|
||||
if diffs_original != diffs:
|
||||
try:
|
||||
get_logger().info(f"Filtered out [ignore] files for pull request:", extra=
|
||||
{"files": diffs_original, # diffs is just a list of names
|
||||
"filtered_files": diffs})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
invalid_files_names = []
|
||||
for file in diffs:
|
||||
if not is_valid_file(file):
|
||||
invalid_files_names.append(file)
|
||||
continue
|
||||
|
||||
version = GitVersionDescriptor(
|
||||
version=head_sha.commit_id, version_type="commit"
|
||||
)
|
||||
try:
|
||||
new_file_content_str = self.azure_devops_client.get_item(
|
||||
repository_id=self.repo_slug,
|
||||
path=file,
|
||||
project=self.workspace_slug,
|
||||
version_descriptor=version,
|
||||
download=False,
|
||||
include_content=True,
|
||||
)
|
||||
|
||||
new_file_content_str = new_file_content_str.content
|
||||
except Exception as error:
|
||||
get_logger().error(f"Failed to retrieve new file content of {file} at version {version}", error=error)
|
||||
# get_logger().error(
|
||||
# "Failed to retrieve new file content of %s at version %s. Error: %s",
|
||||
# file,
|
||||
# version,
|
||||
# str(error),
|
||||
# )
|
||||
new_file_content_str = ""
|
||||
|
||||
edit_type = EDIT_TYPE.MODIFIED
|
||||
if diff_types[file] == "add":
|
||||
edit_type = EDIT_TYPE.ADDED
|
||||
elif diff_types[file] == "delete":
|
||||
edit_type = EDIT_TYPE.DELETED
|
||||
elif "rename" in diff_types[file]: # diff_type can be `rename` | `edit, rename`
|
||||
edit_type = EDIT_TYPE.RENAMED
|
||||
|
||||
version = GitVersionDescriptor(
|
||||
version=base_sha.commit_id, version_type="commit"
|
||||
)
|
||||
if edit_type == EDIT_TYPE.ADDED or edit_type == EDIT_TYPE.RENAMED:
|
||||
original_file_content_str = ""
|
||||
else:
|
||||
try:
|
||||
original_file_content_str = self.azure_devops_client.get_item(
|
||||
repository_id=self.repo_slug,
|
||||
path=file,
|
||||
project=self.workspace_slug,
|
||||
version_descriptor=version,
|
||||
download=False,
|
||||
include_content=True,
|
||||
)
|
||||
original_file_content_str = original_file_content_str.content
|
||||
except Exception as error:
|
||||
get_logger().error(f"Failed to retrieve original file content of {file} at version {version}", error=error)
|
||||
original_file_content_str = ""
|
||||
|
||||
patch = load_large_diff(
|
||||
file, new_file_content_str, original_file_content_str, show_warning=False
|
||||
).rstrip()
|
||||
|
||||
# count number of lines added and removed
|
||||
patch_lines = patch.splitlines(keepends=True)
|
||||
num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
|
||||
num_minus_lines = len([line for line in patch_lines if line.startswith('-')])
|
||||
|
||||
diff_files.append(
|
||||
FilePatchInfo(
|
||||
original_file_content_str,
|
||||
new_file_content_str,
|
||||
patch=patch,
|
||||
filename=file,
|
||||
edit_type=edit_type,
|
||||
num_plus_lines=num_plus_lines,
|
||||
num_minus_lines=num_minus_lines,
|
||||
)
|
||||
)
|
||||
get_logger().info(f"Invalid files: {invalid_files_names}")
|
||||
|
||||
self.diff_files = diff_files
|
||||
return diff_files
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to get diff files, error: {e}")
|
||||
return []
|
||||
|
||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False, thread_context=None):
|
||||
if is_temporary and not get_settings().config.publish_output_progress:
|
||||
get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
|
||||
return None
|
||||
comment = Comment(content=pr_comment)
|
||||
thread = CommentThread(comments=[comment], thread_context=thread_context, status=5)
|
||||
thread_response = self.azure_devops_client.create_thread(
|
||||
comment_thread=thread,
|
||||
project=self.workspace_slug,
|
||||
repository_id=self.repo_slug,
|
||||
pull_request_id=self.pr_num,
|
||||
)
|
||||
response = {"thread_id": thread_response.id, "comment_id": thread_response.comments[0].id}
|
||||
if is_temporary:
|
||||
self.temp_comments.append(response)
|
||||
return response
|
||||
|
||||
def publish_description(self, pr_title: str, pr_body: str):
|
||||
if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:
|
||||
|
||||
usage_guide_text='<details> <summary><strong>✨ Describe tool usage guide:</strong></summary><hr>'
|
||||
ind = pr_body.find(usage_guide_text)
|
||||
if ind != -1:
|
||||
pr_body = pr_body[:ind]
|
||||
|
||||
if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:
|
||||
changes_walkthrough_text = PRDescriptionHeader.CHANGES_WALKTHROUGH.value
|
||||
ind = pr_body.find(changes_walkthrough_text)
|
||||
if ind != -1:
|
||||
pr_body = pr_body[:ind]
|
||||
|
||||
if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:
|
||||
truncation_message = " ... (description truncated due to length limit)"
pr_body = pr_body[:MAX_PR_DESCRIPTION_AZURE_LENGTH - len(truncation_message)] + truncation_message
|
||||
get_logger().warning("PR description was truncated due to length limit")
|
||||
try:
|
||||
updated_pr = GitPullRequest()
|
||||
updated_pr.title = pr_title
|
||||
updated_pr.description = pr_body
|
||||
self.azure_devops_client.update_pull_request(
|
||||
project=self.workspace_slug,
|
||||
repository_id=self.repo_slug,
|
||||
pull_request_id=self.pr_num,
|
||||
git_pull_request_to_update=updated_pr,
|
||||
)
|
||||
except Exception as e:
|
||||
get_logger().exception(
|
||||
f"Could not update pull request {self.pr_num} description: {e}"
|
||||
)
|
||||
|
||||
def remove_initial_comment(self):
|
||||
try:
|
||||
for comment in self.temp_comments:
|
||||
self.remove_comment(comment)
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to remove temp comments, error: {e}")
|
||||
|
||||
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
|
||||
self.publish_inline_comments([self.create_inline_comment(body, relevant_file, relevant_line_in_file)])
|
||||
|
||||
|
||||
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
|
||||
absolute_position: int = None):
|
||||
position, absolute_position = find_line_number_of_relevant_line_in_file(self.get_diff_files(),
|
||||
relevant_file.strip('`'),
|
||||
relevant_line_in_file,
|
||||
absolute_position)
|
||||
if position == -1:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
|
||||
subject_type = "FILE"
|
||||
else:
|
||||
subject_type = "LINE"
|
||||
path = relevant_file.strip()
|
||||
return dict(body=body, path=path, position=position, absolute_position=absolute_position) if subject_type == "LINE" else {}
|
||||
|
||||
def publish_inline_comments(self, comments: list[dict], disable_fallback: bool = False):
|
||||
overall_success = True
|
||||
for comment in comments:
|
||||
try:
|
||||
self.publish_comment(comment["body"],
|
||||
thread_context={
|
||||
"filePath": comment["path"],
|
||||
"rightFileStart": {
|
||||
"line": comment["absolute_position"],
|
||||
"offset": comment["position"],
|
||||
},
|
||||
"rightFileEnd": {
|
||||
"line": comment["absolute_position"],
|
||||
"offset": comment["position"],
|
||||
},
|
||||
})
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(
|
||||
f"Published code suggestion on {self.pr_num} at {comment['path']}"
|
||||
)
|
||||
except Exception as e:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().error(f"Failed to publish code suggestion, error: {e}")
|
||||
overall_success = False
|
||||
return overall_success
|
||||
|
||||
def get_title(self):
|
||||
return self.pr.title
|
||||
|
||||
def get_languages(self):
|
||||
languages = []
|
||||
files = self.azure_devops_client.get_items(
|
||||
project=self.workspace_slug,
|
||||
repository_id=self.repo_slug,
|
||||
recursion_level="Full",
|
||||
include_content_metadata=True,
|
||||
include_links=False,
|
||||
download=False,
|
||||
)
|
||||
for f in files:
|
||||
if f.git_object_type == "blob":
|
||||
file_name, file_extension = os.path.splitext(f.path)
|
||||
languages.append(file_extension[1:])
|
||||
|
||||
extension_counts = {}
|
||||
for ext in languages:
|
||||
if ext != "":
|
||||
extension_counts[ext] = extension_counts.get(ext, 0) + 1
|
||||
|
||||
total_extensions = sum(extension_counts.values())
|
||||
|
||||
extension_percentages = {
|
||||
ext: (count / total_extensions) * 100
|
||||
for ext, count in extension_counts.items()
|
||||
}
|
||||
|
||||
return extension_percentages
|
||||
|
||||
def get_pr_branch(self):
|
||||
pr_info = self.azure_devops_client.get_pull_request_by_id(
|
||||
project=self.workspace_slug, pull_request_id=self.pr_num
|
||||
)
|
||||
source_branch = pr_info.source_ref_name.split("/")[-1]
|
||||
return source_branch
|
||||
|
||||
def get_user_id(self):
|
||||
return 0
|
||||
|
||||
def get_issue_comments(self):
|
||||
threads = self.azure_devops_client.get_threads(repository_id=self.repo_slug, pull_request_id=self.pr_num, project=self.workspace_slug)
|
||||
threads.reverse()
|
||||
comment_list = []
|
||||
for thread in threads:
|
||||
for comment in thread.comments:
|
||||
if comment.content and comment not in comment_list:
|
||||
comment.body = comment.content
|
||||
comment.thread_id = thread.id
|
||||
comment_list.append(comment)
|
||||
return comment_list
|
||||
|
||||
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
|
||||
return True
|
||||
|
||||
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _parse_pr_url(pr_url: str) -> Tuple[str, str, int]:
|
||||
parsed_url = urlparse(pr_url)
|
||||
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
if "pullrequest" not in path_parts:
|
||||
raise ValueError(
|
||||
"The provided URL does not appear to be a Azure DevOps PR URL"
|
||||
)
|
||||
if len(path_parts) == 6: # "https://dev.azure.com/organization/project/_git/repo/pullrequest/1"
|
||||
workspace_slug = path_parts[1]
|
||||
repo_slug = path_parts[3]
|
||||
pr_number = int(path_parts[5])
|
||||
elif len(path_parts) == 5: # 'https://organization.visualstudio.com/project/_git/repo/pullrequest/1'
|
||||
workspace_slug = path_parts[0]
|
||||
repo_slug = path_parts[2]
|
||||
pr_number = int(path_parts[4])
|
||||
else:
|
||||
raise ValueError("The provided URL does not appear to be a Azure DevOps PR URL")
|
||||
|
||||
return workspace_slug, repo_slug, pr_number
|
||||
|
||||
@staticmethod
|
||||
def _get_azure_devops_client():
|
||||
org = get_settings().azure_devops.get("org", None)
|
||||
pat = get_settings().azure_devops.get("pat", None)
|
||||
|
||||
if not org:
|
||||
raise ValueError("Azure DevOps organization is required")
|
||||
|
||||
if pat:
|
||||
auth_token = pat
|
||||
else:
|
||||
try:
|
||||
# try to use azure default credentials
|
||||
# see https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python
|
||||
# for usage and env var configuration of user-assigned managed identity, local machine auth etc.
|
||||
get_logger().info("No PAT found in settings, trying to use Azure Default Credentials.")
|
||||
credentials = DefaultAzureCredential()
|
||||
accessToken = credentials.get_token(ADO_APP_CLIENT_DEFAULT_ID)
|
||||
auth_token = accessToken.token
|
||||
except Exception as e:
|
||||
get_logger().error(f"No PAT found in settings, and Azure Default Authentication failed, error: {e}")
|
||||
raise
|
||||
|
||||
credentials = BasicAuthentication("", auth_token)
|
||||
azure_devops_connection = Connection(base_url=org, creds=credentials)
|
||||
azure_devops_client = azure_devops_connection.clients.get_git_client()
|
||||
|
||||
return azure_devops_client
|
||||
|
||||
def _get_repo(self):
|
||||
if self.repo is None:
|
||||
self.repo = self.azure_devops_client.get_repository(
|
||||
project=self.workspace_slug, repository_id=self.repo_slug
|
||||
)
|
||||
return self.repo
|
||||
|
||||
def _get_pr(self):
|
||||
self.pr = self.azure_devops_client.get_pull_request_by_id(
|
||||
pull_request_id=self.pr_num, project=self.workspace_slug
|
||||
)
|
||||
return self.pr
|
||||
|
||||
def get_commit_messages(self):
|
||||
return "" # not implemented yet
|
||||
|
||||
def get_pr_id(self):
|
||||
try:
|
||||
pr_id = f"{self.workspace_slug}/{self.repo_slug}/{self.pr_num}"
|
||||
return pr_id
|
||||
except Exception as e:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().error(f"Failed to get pr id, error: {e}")
|
||||
return ""
|
||||
|
||||
def publish_file_comments(self, file_comments: list) -> bool:
|
||||
pass
|
||||
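For reference, a hedged sketch of instantiating AzureDevopsProvider directly. The azure_devops.org / azure_devops.pat keys mirror the ones read in _get_azure_devops_client(), and the URL follows the dev.azure.com form handled by _parse_pr_url(); the organization, PAT, and PR URL are placeholders.

```python
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider

# Keys match _get_azure_devops_client(); values are placeholders.
get_settings().set("azure_devops.org", "https://dev.azure.com/organization")
get_settings().set("azure_devops.pat", "...")

provider = AzureDevopsProvider("https://dev.azure.com/organization/project/_git/repo/pullrequest/1")
print(provider.get_title())
print(provider.get_pr_branch())
```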
561
apps/utils/pr_agent/git_providers/bitbucket_provider.py
Normal file
561
apps/utils/pr_agent/git_providers/bitbucket_provider.py
Normal file
@ -0,0 +1,561 @@
|
||||
import difflib
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from atlassian.bitbucket import Cloud
|
||||
from starlette_context import context
|
||||
|
||||
from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
|
||||
from ..algo.file_filter import filter_ignored
|
||||
from ..algo.language_handler import is_valid_file
|
||||
from ..algo.utils import find_line_number_of_relevant_line_in_file
|
||||
from ..config_loader import get_settings
|
||||
from ..log import get_logger
|
||||
from .git_provider import MAX_FILES_ALLOWED_FULL, GitProvider
|
||||
|
||||
|
||||
def _gef_filename(diff):
|
||||
if diff.new.path:
|
||||
return diff.new.path
|
||||
return diff.old.path
|
||||
|
||||
|
||||
class BitbucketProvider(GitProvider):
|
||||
def __init__(
|
||||
self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
|
||||
):
|
||||
s = requests.Session()
|
||||
try:
|
||||
bearer = context.get("bitbucket_bearer_token", None)
|
||||
s.headers["Authorization"] = f"Bearer {bearer}"
|
||||
except Exception:
|
||||
s.headers[
|
||||
"Authorization"
|
||||
] = f'Bearer {get_settings().get("BITBUCKET.BEARER_TOKEN", None)}'
|
||||
s.headers["Content-Type"] = "application/json"
|
||||
self.headers = s.headers
|
||||
self.bitbucket_client = Cloud(session=s)
|
||||
self.max_comment_length = 31000
|
||||
self.workspace_slug = None
|
||||
self.repo_slug = None
|
||||
self.repo = None
|
||||
self.pr_num = None
|
||||
self.pr = None
|
||||
self.pr_url = pr_url
|
||||
self.temp_comments = []
|
||||
self.incremental = incremental
|
||||
self.diff_files = None
|
||||
self.git_files = None
|
||||
if pr_url:
|
||||
self.set_pr(pr_url)
|
||||
self.bitbucket_comment_api_url = self.pr._BitbucketBase__data["links"]["comments"]["href"]
|
||||
self.bitbucket_pull_request_api_url = self.pr._BitbucketBase__data["links"]['self']['href']
|
||||
|
||||
def get_repo_settings(self):
|
||||
try:
|
||||
url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
|
||||
f"{self.pr.destination_branch}/.pr_agent.toml")
|
||||
response = requests.request("GET", url, headers=self.headers)
|
||||
if response.status_code == 404: # not found
|
||||
return ""
|
||||
contents = response.text.encode('utf-8')
|
||||
return contents
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def publish_code_suggestions(self, code_suggestions: list) -> bool:
|
||||
"""
|
||||
Publishes code suggestions as comments on the PR.
|
||||
"""
|
||||
post_parameters_list = []
|
||||
for suggestion in code_suggestions:
|
||||
body = suggestion["body"]
|
||||
original_suggestion = suggestion.get('original_suggestion', None) # needed for diff code
|
||||
if original_suggestion:
|
||||
try:
|
||||
existing_code = original_suggestion['existing_code'].rstrip() + "\n"
|
||||
improved_code = original_suggestion['improved_code'].rstrip() + "\n"
|
||||
diff = difflib.unified_diff(existing_code.split('\n'),
|
||||
improved_code.split('\n'), n=999)
|
||||
patch_orig = "\n".join(diff)
|
||||
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
|
||||
diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
|
||||
# replace ```suggestion ... ``` with diff_code, using regex:
|
||||
body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL)
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Bitbucket failed to get diff code for publishing, error: {e}")
|
||||
continue
|
||||
|
||||
relevant_file = suggestion["relevant_file"]
|
||||
relevant_lines_start = suggestion["relevant_lines_start"]
|
||||
relevant_lines_end = suggestion["relevant_lines_end"]
|
||||
|
||||
if not relevant_lines_start or relevant_lines_start == -1:
|
||||
get_logger().exception(
|
||||
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}"
|
||||
)
|
||||
continue
|
||||
|
||||
if relevant_lines_end < relevant_lines_start:
|
||||
get_logger().exception(
|
||||
f"Failed to publish code suggestion, "
|
||||
f"relevant_lines_end is {relevant_lines_end} and "
|
||||
f"relevant_lines_start is {relevant_lines_start}"
|
||||
)
|
||||
continue
|
||||
|
||||
if relevant_lines_end > relevant_lines_start:
|
||||
post_parameters = {
|
||||
"body": body,
|
||||
"path": relevant_file,
|
||||
"line": relevant_lines_end,
|
||||
"start_line": relevant_lines_start,
|
||||
"start_side": "RIGHT",
|
||||
}
|
||||
else: # API is different for single line comments
|
||||
post_parameters = {
|
||||
"body": body,
|
||||
"path": relevant_file,
|
||||
"line": relevant_lines_start,
|
||||
"side": "RIGHT",
|
||||
}
|
||||
post_parameters_list.append(post_parameters)
|
||||
|
||||
try:
|
||||
self.publish_inline_comments(post_parameters_list)
|
||||
return True
|
||||
except Exception as e:
|
||||
get_logger().error(f"Bitbucket failed to publish code suggestion, error: {e}")
|
||||
return False
|
||||
|
||||
def publish_file_comments(self, file_comments: list) -> bool:
|
||||
pass
|
||||
|
||||
def is_supported(self, capability: str) -> bool:
|
||||
if capability in ['get_issue_comments', 'publish_inline_comments', 'get_labels', 'gfm_markdown',
|
||||
'publish_file_comments']:
|
||||
return False
|
||||
return True
|
||||
|
||||
def set_pr(self, pr_url: str):
|
||||
self.workspace_slug, self.repo_slug, self.pr_num = self._parse_pr_url(pr_url)
|
||||
self.pr = self._get_pr()
|
||||
|
||||
def get_files(self):
|
||||
try:
|
||||
git_files = context.get("git_files", None)
|
||||
if git_files:
|
||||
return git_files
|
||||
self.git_files = [_gef_filename(diff) for diff in self.pr.diffstat()]
|
||||
context["git_files"] = self.git_files
|
||||
return self.git_files
|
||||
except Exception:
|
||||
if not self.git_files:
|
||||
self.git_files = [_gef_filename(diff) for diff in self.pr.diffstat()]
|
||||
return self.git_files
|
||||
|
||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||
if self.diff_files:
|
||||
return self.diff_files
|
||||
|
||||
diffs_original = list(self.pr.diffstat())
|
||||
diffs = filter_ignored(diffs_original, 'bitbucket')
|
||||
if diffs != diffs_original:
|
||||
try:
|
||||
names_original = [d.new.path for d in diffs_original]
|
||||
names_kept = [d.new.path for d in diffs]
|
||||
names_filtered = list(set(names_original) - set(names_kept))
|
||||
get_logger().info(f"Filtered out [ignore] files for PR", extra={
|
||||
'original_files': names_original,
|
||||
'names_kept': names_kept,
|
||||
'names_filtered': names_filtered
|
||||
|
||||
})
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# get the pr patches
|
||||
try:
|
||||
pr_patches = self.pr.diff()
|
||||
except Exception as e:
|
||||
# Try different encodings if UTF-8 fails
|
||||
get_logger().warning(f"Failed to decode PR patch with utf-8, error: {e}")
|
||||
encodings_to_try = ['iso-8859-1', 'latin-1', 'ascii', 'utf-16']
|
||||
pr_patches = None
|
||||
for encoding in encodings_to_try:
|
||||
try:
|
||||
pr_patches = self.pr.diff(encoding=encoding)
|
||||
get_logger().info(f"Successfully decoded PR patch with encoding {encoding}")
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
|
||||
if pr_patches is None:
|
||||
raise ValueError(f"Failed to decode PR patch with encodings {encodings_to_try}")
|
||||
|
||||
diff_split = ["diff --git" + x for x in pr_patches.split("diff --git") if x.strip()]
|
||||
# filter all elements of 'diff_split' that are of indices in 'diffs_original' that are not in 'diffs'
|
||||
if len(diff_split) > len(diffs) and len(diffs_original) == len(diff_split):
|
||||
diff_split = [diff_split[i] for i in range(len(diff_split)) if diffs_original[i] in diffs]
|
||||
if len(diff_split) != len(diffs):
|
||||
get_logger().error(f"Error - failed to split the diff into {len(diffs)} parts")
|
||||
return []
|
||||
# bitbucket diff has a header for each file, we need to remove it:
|
||||
# "diff --git filename
|
||||
# new file mode 100644 (optional)
|
||||
# index caa56f0..61528d7 100644
|
||||
# --- a/pr_agent/cli_pip.py
|
||||
# +++ b/pr_agent/cli_pip.py
|
||||
# @@ -... @@"
|
||||
for i, _ in enumerate(diff_split):
|
||||
diff_split_lines = diff_split[i].splitlines()
|
||||
if (len(diff_split_lines) >= 6) and \
|
||||
((diff_split_lines[2].startswith("---") and
|
||||
diff_split_lines[3].startswith("+++") and
|
||||
diff_split_lines[4].startswith("@@")) or
|
||||
(diff_split_lines[3].startswith("---") and # new or deleted file
|
||||
diff_split_lines[4].startswith("+++") and
|
||||
diff_split_lines[5].startswith("@@"))):
|
||||
diff_split[i] = "\n".join(diff_split_lines[4:])
|
||||
else:
|
||||
if diffs[i].data.get('lines_added', 0) == 0 and diffs[i].data.get('lines_removed', 0) == 0:
|
||||
diff_split[i] = ""
|
||||
elif len(diff_split_lines) <= 3:
|
||||
diff_split[i] = ""
|
||||
get_logger().info(f"Disregarding empty diff for file {_gef_filename(diffs[i])}")
|
||||
else:
|
||||
get_logger().warning(f"Bitbucket failed to get diff for file {_gef_filename(diffs[i])}")
|
||||
diff_split[i] = ""
|
||||
|
||||
invalid_files_names = []
|
||||
diff_files = []
|
||||
counter_valid = 0
|
||||
# get full files
|
||||
for index, diff in enumerate(diffs):
|
||||
file_path = _gef_filename(diff)
|
||||
if not is_valid_file(file_path):
|
||||
invalid_files_names.append(file_path)
|
||||
continue
|
||||
|
||||
try:
|
||||
counter_valid += 1
|
||||
if get_settings().get("bitbucket_app.avoid_full_files", False):
|
||||
original_file_content_str = ""
|
||||
new_file_content_str = ""
|
||||
elif counter_valid < MAX_FILES_ALLOWED_FULL // 2: # factor 2 because bitbucket has limited API calls
|
||||
if diff.old.get_data("links"):
|
||||
original_file_content_str = self._get_pr_file_content(
|
||||
diff.old.get_data("links")['self']['href'])
|
||||
else:
|
||||
original_file_content_str = ""
|
||||
if diff.new.get_data("links"):
|
||||
new_file_content_str = self._get_pr_file_content(diff.new.get_data("links")['self']['href'])
|
||||
else:
|
||||
new_file_content_str = ""
|
||||
else:
|
||||
if counter_valid == MAX_FILES_ALLOWED_FULL // 2:
|
||||
get_logger().info(
|
||||
f"Bitbucket too many files in PR, will avoid loading full content for rest of files")
|
||||
original_file_content_str = ""
|
||||
new_file_content_str = ""
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Error - bitbucket failed to get file content, error: {e}")
|
||||
original_file_content_str = ""
|
||||
new_file_content_str = ""
|
||||
|
||||
file_patch_canonic_structure = FilePatchInfo(
|
||||
original_file_content_str,
|
||||
new_file_content_str,
|
||||
diff_split[index],
|
||||
file_path,
|
||||
)
|
||||
|
||||
if diff.data['status'] == 'added':
|
||||
file_patch_canonic_structure.edit_type = EDIT_TYPE.ADDED
|
||||
elif diff.data['status'] == 'removed':
|
||||
file_patch_canonic_structure.edit_type = EDIT_TYPE.DELETED
|
||||
elif diff.data['status'] == 'modified':
|
||||
file_patch_canonic_structure.edit_type = EDIT_TYPE.MODIFIED
|
||||
elif diff.data['status'] == 'renamed':
|
||||
file_patch_canonic_structure.edit_type = EDIT_TYPE.RENAMED
|
||||
diff_files.append(file_patch_canonic_structure)
|
||||
|
||||
if invalid_files_names:
|
||||
get_logger().info(f"Disregarding files with invalid extensions:\n{invalid_files_names}")
|
||||
|
||||
self.diff_files = diff_files
|
||||
return diff_files
|
||||
|
||||
def get_latest_commit_url(self):
|
||||
return self.pr.data['source']['commit']['links']['html']['href']
|
||||
|
||||
def get_comment_url(self, comment):
|
||||
return comment.data['links']['html']['href']
|
||||
|
||||
def publish_persistent_comment(self, pr_comment: str,
|
||||
initial_header: str,
|
||||
update_header: bool = True,
|
||||
name='review',
|
||||
final_update_message=True):
|
||||
try:
|
||||
for comment in self.pr.comments():
|
||||
body = comment.raw
|
||||
if initial_header in body:
|
||||
latest_commit_url = self.get_latest_commit_url()
|
||||
comment_url = self.get_comment_url(comment)
|
||||
if update_header:
|
||||
updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n"
|
||||
pr_comment_updated = pr_comment.replace(initial_header, updated_header)
|
||||
else:
|
||||
pr_comment_updated = pr_comment
|
||||
get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message")
|
||||
d = {"content": {"raw": pr_comment_updated}}
|
||||
response = comment._update_data(comment.put(None, data=d))
|
||||
if final_update_message:
|
||||
self.publish_comment(
|
||||
f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}")
|
||||
return
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to update persistent review, error: {e}")
|
||||
pass
|
||||
self.publish_comment(pr_comment)
|
||||
|
||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||
if is_temporary and not get_settings().config.publish_output_progress:
|
||||
get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
|
||||
return None
|
||||
pr_comment = self.limit_output_characters(pr_comment, self.max_comment_length)
|
||||
comment = self.pr.comment(pr_comment)
|
||||
if is_temporary:
|
||||
self.temp_comments.append(comment["id"])
|
||||
return comment
|
||||
|
||||
def edit_comment(self, comment, body: str):
|
||||
try:
|
||||
body = self.limit_output_characters(body, self.max_comment_length)
|
||||
comment.update(body)
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to update comment, error: {e}")
|
||||
|
||||
def remove_initial_comment(self):
|
||||
try:
|
||||
for comment in self.temp_comments:
|
||||
self.remove_comment(comment)
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to remove temp comments, error: {e}")
|
||||
|
||||
def remove_comment(self, comment):
|
||||
try:
|
||||
self.pr.delete(f"comments/{comment}")
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to remove comment, error: {e}")
|
||||
|
||||
# function to create_inline_comment
|
||||
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
|
||||
absolute_position: int = None):
|
||||
body = self.limit_output_characters(body, self.max_comment_length)
|
||||
position, absolute_position = find_line_number_of_relevant_line_in_file(self.get_diff_files(),
|
||||
relevant_file.strip('`'),
|
||||
relevant_line_in_file,
|
||||
absolute_position)
|
||||
if position == -1:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
|
||||
subject_type = "FILE"
|
||||
else:
|
||||
subject_type = "LINE"
|
||||
path = relevant_file.strip()
|
||||
return dict(body=body, path=path, position=absolute_position) if subject_type == "LINE" else {}
|
||||
|
||||
def publish_inline_comment(self, comment: str, from_line: int, file: str, original_suggestion=None):
|
||||
comment = self.limit_output_characters(comment, self.max_comment_length)
|
||||
payload = json.dumps({
|
||||
"content": {
|
||||
"raw": comment,
|
||||
},
|
||||
"inline": {
|
||||
"to": from_line,
|
||||
"path": file
|
||||
},
|
||||
})
|
||||
response = requests.request(
|
||||
"POST", self.bitbucket_comment_api_url, data=payload, headers=self.headers
|
||||
)
|
||||
return response
|
||||
|
||||
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
|
||||
if relevant_line_start == -1:
|
||||
link = f"{self.pr_url}/#L{relevant_file}"
|
||||
else:
|
||||
link = f"{self.pr_url}/#L{relevant_file}T{relevant_line_start}"
|
||||
return link
|
||||
|
||||
def generate_link_to_relevant_line_number(self, suggestion) -> str:
|
||||
try:
|
||||
relevant_file = suggestion['relevant_file'].strip('`').strip("'").rstrip()
|
||||
relevant_line_str = suggestion['relevant_line'].rstrip()
|
||||
if not relevant_line_str:
|
||||
return ""
|
||||
|
||||
diff_files = self.get_diff_files()
|
||||
position, absolute_position = find_line_number_of_relevant_line_in_file \
|
||||
(diff_files, relevant_file, relevant_line_str)
|
||||
|
||||
if absolute_position != -1 and self.pr_url:
|
||||
link = f"{self.pr_url}/#L{relevant_file}T{absolute_position}"
|
||||
return link
|
||||
except Exception as e:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Failed adding line link, error: {e}")
|
||||
|
||||
return ""
|
||||
|
||||
def publish_inline_comments(self, comments: list[dict]):
|
||||
for comment in comments:
|
||||
if 'position' in comment:
|
||||
self.publish_inline_comment(comment['body'], comment['position'], comment['path'])
|
||||
elif 'start_line' in comment: # multi-line comment
|
||||
# note that bitbucket does not seem to support range - only a comment on a single line - https://community.developer.atlassian.com/t/api-post-endpoint-for-inline-pull-request-comments/60452
|
||||
self.publish_inline_comment(comment['body'], comment['start_line'], comment['path'])
|
||||
elif 'line' in comment: # single-line comment
|
||||
self.publish_inline_comment(comment['body'], comment['line'], comment['path'])
|
||||
else:
|
||||
get_logger().error(f"Could not publish inline comment {comment}")
|
||||
|
||||
def get_title(self):
|
||||
return self.pr.title
|
||||
|
||||
def get_languages(self):
|
||||
languages = {self._get_repo().get_data("language"): 0}
|
||||
return languages
|
||||
|
||||
def get_pr_branch(self):
|
||||
return self.pr.source_branch
|
||||
|
||||
def get_pr_owner_id(self) -> str | None:
|
||||
return self.workspace_slug
|
||||
|
||||
def get_pr_description_full(self):
|
||||
return self.pr.description
|
||||
|
||||
def get_user_id(self):
|
||||
return 0
|
||||
|
||||
def get_issue_comments(self):
|
||||
raise NotImplementedError(
|
||||
"Bitbucket provider does not support issue comments yet"
|
||||
)
|
||||
|
||||
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
|
||||
return True
|
||||
|
||||
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _parse_pr_url(pr_url: str) -> Tuple[str, str, int]:
|
||||
parsed_url = urlparse(pr_url)
|
||||
|
||||
if "bitbucket.org" not in parsed_url.netloc:
|
||||
raise ValueError("The provided URL is not a valid Bitbucket URL")
|
||||
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
|
||||
if len(path_parts) < 4 or path_parts[2] != "pull-requests":
|
||||
raise ValueError(
|
||||
"The provided URL does not appear to be a Bitbucket PR URL"
|
||||
)
|
||||
|
||||
workspace_slug = path_parts[0]
|
||||
repo_slug = path_parts[1]
|
||||
try:
|
||||
pr_number = int(path_parts[3])
|
||||
except ValueError as e:
|
||||
raise ValueError("Unable to convert PR number to integer") from e
|
||||
|
||||
return workspace_slug, repo_slug, pr_number
|
||||
|
||||
def _get_repo(self):
|
||||
if self.repo is None:
|
||||
self.repo = self.bitbucket_client.workspaces.get(
|
||||
self.workspace_slug
|
||||
).repositories.get(self.repo_slug)
|
||||
return self.repo
|
||||
|
||||
def _get_pr(self):
|
||||
return self._get_repo().pullrequests.get(self.pr_num)
|
||||
|
||||
def get_pr_file_content(self, file_path: str, branch: str) -> str:
|
||||
try:
|
||||
if branch == self.pr.source_branch:
|
||||
branch = self.pr.data["source"]["commit"]["hash"]
|
||||
elif branch == self.pr.destination_branch:
|
||||
branch = self.pr.data["destination"]["commit"]["hash"]
|
||||
url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
|
||||
f"{branch}/{file_path}")
|
||||
response = requests.request("GET", url, headers=self.headers)
|
||||
if response.status_code == 404: # not found
|
||||
return ""
|
||||
contents = response.text
|
||||
return contents
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def create_or_update_pr_file(self, file_path: str, branch: str, contents="", message="") -> None:
|
||||
url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/")
|
||||
if not message:
|
||||
if contents:
|
||||
message = f"Update {file_path}"
|
||||
else:
|
||||
message = f"Create {file_path}"
|
||||
files = {file_path: contents}
|
||||
data = {
|
||||
"message": message,
|
||||
"branch": branch
|
||||
}
|
||||
headers = {'Authorization': self.headers['Authorization']} if 'Authorization' in self.headers else {}
|
||||
try:
|
||||
requests.request("POST", url, headers=headers, data=data, files=files)
|
||||
except Exception:
|
||||
get_logger().exception(f"Failed to create empty file {file_path} in branch {branch}")
|
||||
|
||||
def _get_pr_file_content(self, remote_link: str):
|
||||
try:
|
||||
response = requests.request("GET", remote_link, headers=self.headers)
|
||||
if response.status_code == 404: # not found
|
||||
return ""
|
||||
contents = response.text
|
||||
return contents
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def get_commit_messages(self):
|
||||
return "" # not implemented yet
|
||||
|
||||
# bitbucket does not support labels
|
||||
def publish_description(self, pr_title: str, description: str):
|
||||
payload = json.dumps({
|
||||
"description": description,
|
||||
"title": pr_title
|
||||
|
||||
})
|
||||
|
||||
response = requests.request("PUT", self.bitbucket_pull_request_api_url, headers=self.headers, data=payload)
|
||||
try:
|
||||
if response.status_code != 200:
|
||||
get_logger().info(f"Failed to update description, error code: {response.status_code}")
|
||||
except:
|
||||
pass
|
||||
return response
|
||||
|
||||
# bitbucket does not support labels
|
||||
def publish_labels(self, pr_types: list):
|
||||
pass
|
||||
|
||||
# bitbucket does not support labels
|
||||
def get_pr_labels(self, update=False):
|
||||
pass
|
||||
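Similarly, a minimal sketch of using BitbucketProvider outside a web context. The BITBUCKET.BEARER_TOKEN key matches the fallback read in __init__, and the URL shape (workspace/repo/pull-requests/N) follows _parse_pr_url(); the token and URL are placeholders.

```python
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers.bitbucket_provider import BitbucketProvider

# Key matches the fallback in BitbucketProvider.__init__; values are placeholders.
get_settings().set("BITBUCKET.BEARER_TOKEN", "...")

pr = BitbucketProvider("https://bitbucket.org/workspace/repo/pull-requests/1")
print(pr.get_title())
print(pr.get_pr_branch())
```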
483
apps/utils/pr_agent/git_providers/bitbucket_server_provider.py
Normal file
483
apps/utils/pr_agent/git_providers/bitbucket_server_provider.py
Normal file
@ -0,0 +1,483 @@
|
||||
import difflib
|
||||
import re
|
||||
|
||||
from packaging.version import parse as parse_version
|
||||
from typing import Optional, Tuple
|
||||
from urllib.parse import quote_plus, urlparse
|
||||
|
||||
from atlassian.bitbucket import Bitbucket
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from ..algo.git_patch_processing import decode_if_bytes
|
||||
from ..algo.language_handler import is_valid_file
|
||||
from ..algo.types import EDIT_TYPE, FilePatchInfo
|
||||
from ..algo.utils import (find_line_number_of_relevant_line_in_file,
|
||||
load_large_diff)
|
||||
from ..config_loader import get_settings
|
||||
from ..log import get_logger
|
||||
from .git_provider import GitProvider
|
||||
|
||||
|
||||
class BitbucketServerProvider(GitProvider):
|
||||
def __init__(
|
||||
self, pr_url: Optional[str] = None, incremental: Optional[bool] = False,
|
||||
bitbucket_client: Optional[Bitbucket] = None,
|
||||
):
|
||||
self.bitbucket_server_url = None
|
||||
self.workspace_slug = None
|
||||
self.repo_slug = None
|
||||
self.repo = None
|
||||
self.pr_num = None
|
||||
self.pr = None
|
||||
self.pr_url = pr_url
|
||||
self.temp_comments = []
|
||||
self.incremental = incremental
|
||||
self.diff_files = None
|
||||
self.bitbucket_pull_request_api_url = pr_url
|
||||
|
||||
self.bitbucket_server_url = self._parse_bitbucket_server(url=pr_url)
|
||||
self.bitbucket_client = bitbucket_client or Bitbucket(url=self.bitbucket_server_url,
|
||||
token=get_settings().get("BITBUCKET_SERVER.BEARER_TOKEN",
|
||||
None))
|
||||
try:
|
||||
self.bitbucket_api_version = parse_version(self.bitbucket_client.get("rest/api/1.0/application-properties").get('version'))
|
||||
except Exception:
|
||||
self.bitbucket_api_version = None
|
||||
|
||||
if pr_url:
|
||||
self.set_pr(pr_url)
|
||||
|
||||
def get_repo_settings(self):
|
||||
try:
|
||||
content = self.bitbucket_client.get_content_of_file(self.workspace_slug, self.repo_slug, ".pr_agent.toml", self.get_pr_branch())
|
||||
|
||||
return content
|
||||
except Exception as e:
|
||||
if isinstance(e, HTTPError):
|
||||
if e.response.status_code == 404: # not found
|
||||
return ""
|
||||
|
||||
get_logger().error(f"Failed to load .pr_agent.toml file, error: {e}")
|
||||
return ""
|
||||
|
||||
def get_pr_id(self):
|
||||
return self.pr_num
|
||||
|
||||
def publish_code_suggestions(self, code_suggestions: list) -> bool:
|
||||
"""
|
||||
Publishes code suggestions as comments on the PR.
|
||||
"""
|
||||
post_parameters_list = []
|
||||
for suggestion in code_suggestions:
|
||||
body = suggestion["body"]
|
||||
original_suggestion = suggestion.get('original_suggestion', None) # needed for diff code
|
||||
if original_suggestion:
|
||||
try:
|
||||
existing_code = original_suggestion['existing_code'].rstrip() + "\n"
|
||||
improved_code = original_suggestion['improved_code'].rstrip() + "\n"
|
||||
diff = difflib.unified_diff(existing_code.split('\n'),
|
||||
improved_code.split('\n'), n=999)
|
||||
patch_orig = "\n".join(diff)
|
||||
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
|
||||
diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
|
||||
# replace ```suggestion ... ``` with diff_code, using regex:
|
||||
body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL)
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Bitbucket failed to get diff code for publishing, error: {e}")
|
||||
continue
|
||||
relevant_file = suggestion["relevant_file"]
|
||||
relevant_lines_start = suggestion["relevant_lines_start"]
|
||||
relevant_lines_end = suggestion["relevant_lines_end"]
|
||||
|
||||
if not relevant_lines_start or relevant_lines_start == -1:
|
||||
get_logger().warning(
|
||||
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}"
|
||||
)
|
||||
continue
|
||||
|
||||
if relevant_lines_end < relevant_lines_start:
|
||||
get_logger().warning(
|
||||
f"Failed to publish code suggestion, "
|
||||
f"relevant_lines_end is {relevant_lines_end} and "
|
||||
f"relevant_lines_start is {relevant_lines_start}"
|
||||
)
|
||||
continue
|
||||
|
||||
if relevant_lines_end > relevant_lines_start:
|
||||
# Bitbucket does not support multi-line suggestions so use a code block instead - https://jira.atlassian.com/browse/BSERV-4553
|
||||
body = body.replace("```suggestion", "```")
|
||||
post_parameters = {
|
||||
"body": body,
|
||||
"path": relevant_file,
|
||||
"line": relevant_lines_end,
|
||||
"start_line": relevant_lines_start,
|
||||
"start_side": "RIGHT",
|
||||
}
|
||||
else: # API is different for single line comments
|
||||
post_parameters = {
|
||||
"body": body,
|
||||
"path": relevant_file,
|
||||
"line": relevant_lines_start,
|
||||
"side": "RIGHT",
|
||||
}
|
||||
post_parameters_list.append(post_parameters)
|
||||
|
||||
try:
|
||||
self.publish_inline_comments(post_parameters_list)
|
||||
return True
|
||||
except Exception as e:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().error(f"Failed to publish code suggestion, error: {e}")
|
||||
return False
|
||||
|
||||
def publish_file_comments(self, file_comments: list) -> bool:
|
||||
pass
|
||||
|
||||
def is_supported(self, capability: str) -> bool:
|
||||
if capability in ['get_issue_comments', 'get_labels', 'gfm_markdown', 'publish_file_comments']:
|
||||
return False
|
||||
return True
|
||||
|
||||
def set_pr(self, pr_url: str):
|
||||
self.workspace_slug, self.repo_slug, self.pr_num = self._parse_pr_url(pr_url)
|
||||
self.pr = self._get_pr()
|
||||
|
||||
def get_file(self, path: str, commit_id: str):
|
||||
file_content = ""
|
||||
try:
|
||||
file_content = self.bitbucket_client.get_content_of_file(self.workspace_slug,
|
||||
self.repo_slug,
|
||||
path,
|
||||
commit_id)
|
||||
except HTTPError as e:
|
||||
get_logger().debug(f"File {path} not found at commit id: {commit_id}")
|
||||
return file_content
|
||||
|
||||
def get_files(self):
|
||||
changes = self.bitbucket_client.get_pull_requests_changes(self.workspace_slug, self.repo_slug, self.pr_num)
|
||||
diffstat = [change["path"]['toString'] for change in changes]
|
||||
return diffstat
|
||||
|
||||
#gets the best common ancestor: https://git-scm.com/docs/git-merge-base
|
||||
@staticmethod
|
||||
def get_best_common_ancestor(source_commits_list, destination_commits_list, guaranteed_common_ancestor) -> str:
|
||||
destination_commit_hashes = {commit['id'] for commit in destination_commits_list} | {guaranteed_common_ancestor}
|
||||
|
||||
for commit in source_commits_list:
|
||||
for parent_commit in commit['parents']:
|
||||
if parent_commit['id'] in destination_commit_hashes:
|
||||
return parent_commit['id']
|
||||
|
||||
return guaranteed_common_ancestor
|
||||
|
||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||
if self.diff_files:
|
||||
return self.diff_files
|
||||
|
||||
head_sha = self.pr.fromRef['latestCommit']
|
||||
|
||||
# if Bitbucket api version is >= 8.16 then use the merge-base api for 2-way diff calculation
|
||||
if self.bitbucket_api_version is not None and self.bitbucket_api_version >= parse_version("8.16"):
|
||||
try:
|
||||
base_sha = self.bitbucket_client.get(self._get_merge_base())['id']
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to get the best common ancestor for PR: {self.pr_url}, \nerror: {e}")
|
||||
raise e
|
||||
else:
|
||||
source_commits_list = list(self.bitbucket_client.get_pull_requests_commits(
|
||||
self.workspace_slug,
|
||||
self.repo_slug,
|
||||
self.pr_num
|
||||
))
|
||||
# if Bitbucket api version is None or < 7.0 then do a simple diff with a guaranteed common ancestor
|
||||
base_sha = source_commits_list[-1]['parents'][0]['id']
|
||||
# if Bitbucket api version is 7.0-8.15 then use 2-way diff functionality for the base_sha
|
||||
if self.bitbucket_api_version is not None and self.bitbucket_api_version >= parse_version("7.0"):
|
||||
try:
|
||||
destination_commits = list(
|
||||
self.bitbucket_client.get_commits(self.workspace_slug, self.repo_slug, base_sha,
|
||||
self.pr.toRef['latestCommit']))
|
||||
base_sha = self.get_best_common_ancestor(source_commits_list, destination_commits, base_sha)
|
||||
except Exception as e:
|
||||
get_logger().error(
|
||||
f"Failed to get the commit list for calculating best common ancestor for PR: {self.pr_url}, \nerror: {e}")
|
||||
raise e
|
||||
|
||||
diff_files = []
|
||||
original_file_content_str = ""
|
||||
new_file_content_str = ""
|
||||
|
||||
changes = self.bitbucket_client.get_pull_requests_changes(self.workspace_slug, self.repo_slug, self.pr_num)
|
||||
for change in changes:
|
||||
file_path = change['path']['toString']
|
||||
if not is_valid_file(file_path.split("/")[-1]):
|
||||
get_logger().info(f"Skipping a non-code file: {file_path}")
|
||||
continue
|
||||
|
||||
match change['type']:
|
||||
case 'ADD':
|
||||
edit_type = EDIT_TYPE.ADDED
|
||||
new_file_content_str = self.get_file(file_path, head_sha)
|
||||
new_file_content_str = decode_if_bytes(new_file_content_str)
|
||||
original_file_content_str = ""
|
||||
case 'DELETE':
|
||||
edit_type = EDIT_TYPE.DELETED
|
||||
new_file_content_str = ""
|
||||
original_file_content_str = self.get_file(file_path, base_sha)
|
||||
original_file_content_str = decode_if_bytes(original_file_content_str)
|
||||
case 'RENAME':
|
||||
edit_type = EDIT_TYPE.RENAMED
|
||||
case _:
|
||||
edit_type = EDIT_TYPE.MODIFIED
|
||||
original_file_content_str = self.get_file(file_path, base_sha)
|
||||
original_file_content_str = decode_if_bytes(original_file_content_str)
|
||||
new_file_content_str = self.get_file(file_path, head_sha)
|
||||
new_file_content_str = decode_if_bytes(new_file_content_str)
|
||||
|
||||
patch = load_large_diff(file_path, new_file_content_str, original_file_content_str, show_warning=False)
|
||||
|
||||
diff_files.append(
|
||||
FilePatchInfo(
|
||||
original_file_content_str,
|
||||
new_file_content_str,
|
||||
patch,
|
||||
file_path,
|
||||
edit_type=edit_type,
|
||||
)
|
||||
)
|
||||
|
||||
self.diff_files = diff_files
|
||||
return diff_files
|
||||
|
||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||
if not is_temporary:
|
||||
self.bitbucket_client.add_pull_request_comment(self.workspace_slug, self.repo_slug, self.pr_num, pr_comment)
|
||||
|
||||
def remove_initial_comment(self):
|
||||
try:
|
||||
for comment in self.temp_comments:
|
||||
self.remove_comment(comment)
|
||||
except ValueError as e:
|
||||
get_logger().exception(f"Failed to remove temp comments, error: {e}")
|
||||
|
||||
def remove_comment(self, comment):
|
||||
pass
|
||||
|
||||
# function to create_inline_comment
|
||||
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
|
||||
absolute_position: int = None):
|
||||
|
||||
position, absolute_position = find_line_number_of_relevant_line_in_file(
|
||||
self.get_diff_files(),
|
||||
relevant_file.strip('`'),
|
||||
relevant_line_in_file,
|
||||
absolute_position
|
||||
)
|
||||
if position == -1:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
|
||||
subject_type = "FILE"
|
||||
else:
|
||||
subject_type = "LINE"
|
||||
path = relevant_file.strip()
|
||||
return dict(body=body, path=path, position=absolute_position) if subject_type == "LINE" else {}
|
||||
|
||||
def publish_inline_comment(self, comment: str, from_line: int, file: str, original_suggestion=None):
|
||||
payload = {
|
||||
"text": comment,
|
||||
"severity": "NORMAL",
|
||||
"anchor": {
|
||||
"diffType": "EFFECTIVE",
|
||||
"path": file,
|
||||
"lineType": "ADDED",
|
||||
"line": from_line,
|
||||
"fileType": "TO"
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
self.bitbucket_client.post(self._get_pr_comments_path(), data=payload)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to publish inline comment to '{file}' at line {from_line}, error: {e}")
|
||||
raise e
|
||||
|
||||
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
|
||||
if relevant_line_start == -1:
|
||||
link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}"
|
||||
else:
|
||||
link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}?t={relevant_line_start}"
|
||||
return link
|
||||
|
||||
def generate_link_to_relevant_line_number(self, suggestion) -> str:
|
||||
try:
|
||||
relevant_file = suggestion['relevant_file'].strip('`').strip("'").rstrip()
|
||||
relevant_line_str = suggestion['relevant_line'].rstrip()
|
||||
if not relevant_line_str:
|
||||
return ""
|
||||
|
||||
diff_files = self.get_diff_files()
|
||||
position, absolute_position = find_line_number_of_relevant_line_in_file \
|
||||
(diff_files, relevant_file, relevant_line_str)
|
||||
|
||||
if absolute_position != -1:
|
||||
if self.pr:
|
||||
link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}?t={absolute_position}"
|
||||
return link
|
||||
else:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Failed adding line link to '{relevant_file}' since PR not set")
|
||||
else:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Failed adding line link to '{relevant_file}' since position not found")
|
||||
|
||||
if absolute_position != -1 and self.pr_url:
|
||||
link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}?t={absolute_position}"
|
||||
return link
|
||||
except Exception as e:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Failed adding line link to '{relevant_file}', error: {e}")
|
||||
|
||||
return ""
|
||||
|
||||
def publish_inline_comments(self, comments: list[dict]):
|
||||
for comment in comments:
|
||||
if 'position' in comment:
|
||||
self.publish_inline_comment(comment['body'], comment['position'], comment['path'])
|
||||
elif 'start_line' in comment: # multi-line comment
|
||||
# note that bitbucket does not seem to support range - only a comment on a single line - https://community.developer.atlassian.com/t/api-post-endpoint-for-inline-pull-request-comments/60452
|
||||
self.publish_inline_comment(comment['body'], comment['start_line'], comment['path'])
|
||||
elif 'line' in comment: # single-line comment
|
||||
self.publish_inline_comment(comment['body'], comment['line'], comment['path'])
|
||||
else:
|
||||
get_logger().error(f"Could not publish inline comment: {comment}")
|
||||
|
||||
def get_title(self):
|
||||
return self.pr.title
|
||||
|
||||
def get_languages(self):
|
||||
return {"yaml": 0} # devops LOL
|
||||
|
||||
def get_pr_branch(self):
|
||||
return self.pr.fromRef['displayId']
|
||||
|
||||
def get_pr_owner_id(self) -> str | None:
|
||||
return self.workspace_slug
|
||||
|
||||
def get_pr_description_full(self):
|
||||
if hasattr(self.pr, "description"):
|
||||
return self.pr.description
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_user_id(self):
|
||||
return 0
|
||||
|
||||
def get_issue_comments(self):
|
||||
raise NotImplementedError(
|
||||
"Bitbucket provider does not support issue comments yet"
|
||||
)
|
||||
|
||||
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
|
||||
return True
|
||||
|
||||
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _parse_bitbucket_server(url: str) -> str:
|
||||
# pr url format: f"{bitbucket_server}/projects/{project_name}/repos/{repository_name}/pull-requests/{pr_id}"
|
||||
parsed_url = urlparse(url)
|
||||
server_path = parsed_url.path.split("/projects/")
|
||||
if len(server_path) > 1:
|
||||
server_path = server_path[0].strip("/")
|
||||
return f"{parsed_url.scheme}://{parsed_url.netloc}/{server_path}".strip("/")
|
||||
return f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
@staticmethod
|
||||
def _parse_pr_url(pr_url: str) -> Tuple[str, str, int]:
|
||||
# pr url format: f"{bitbucket_server}/projects/{project_name}/repos/{repository_name}/pull-requests/{pr_id}"
|
||||
parsed_url = urlparse(pr_url)
|
||||
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
|
||||
try:
|
||||
projects_index = path_parts.index("projects")
|
||||
except ValueError:
|
||||
projects_index = -1
|
||||
|
||||
try:
|
||||
users_index = path_parts.index("users")
|
||||
except ValueError:
|
||||
users_index = -1
|
||||
|
||||
if projects_index == -1 and users_index == -1:
|
||||
raise ValueError(f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL")
|
||||
|
||||
if projects_index != -1:
|
||||
path_parts = path_parts[projects_index:]
|
||||
else:
|
||||
path_parts = path_parts[users_index:]
|
||||
|
||||
if len(path_parts) < 6 or path_parts[2] != "repos" or path_parts[4] != "pull-requests":
|
||||
raise ValueError(
|
||||
f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL"
|
||||
)
|
||||
|
||||
workspace_slug = path_parts[1]
|
||||
if users_index != -1:
|
||||
workspace_slug = f"~{workspace_slug}"
|
||||
repo_slug = path_parts[3]
|
||||
try:
|
||||
pr_number = int(path_parts[5])
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Unable to convert PR number '{path_parts[5]}' to integer") from e
|
||||
|
||||
return workspace_slug, repo_slug, pr_number
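# A minimal usage sketch (hypothetical URLs; the enclosing class is assumed to be
# the Bitbucket Server provider):
#   _parse_pr_url("https://bitbucket.example.com/projects/PROJ/repos/my-repo/pull-requests/42")
#   -> ("PROJ", "my-repo", 42)
#   _parse_pr_url("https://bitbucket.example.com/users/jdoe/repos/my-repo/pull-requests/7")
#   -> ("~jdoe", "my-repo", 7)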
|
||||
|
||||
def _get_repo(self):
|
||||
if self.repo is None:
|
||||
self.repo = self.bitbucket_client.get_repo(self.workspace_slug, self.repo_slug)
|
||||
return self.repo
|
||||
|
||||
def _get_pr(self):
|
||||
try:
|
||||
pr = self.bitbucket_client.get_pull_request(self.workspace_slug, self.repo_slug,
|
||||
pull_request_id=self.pr_num)
|
||||
return type('new_dict', (object,), pr)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to get pull request, error: {e}")
|
||||
raise e
|
||||
|
||||
def _get_pr_file_content(self, remote_link: str):
|
||||
return ""
|
||||
|
||||
def get_commit_messages(self):
|
||||
return ""
|
||||
|
||||
# bitbucket does not support labels
|
||||
def publish_description(self, pr_title: str, description: str):
|
||||
payload = {
|
||||
"version": self.pr.version,
|
||||
"description": description,
|
||||
"title": pr_title,
|
||||
"reviewers": self.pr.reviewers # needs to be sent otherwise gets wiped
|
||||
}
|
||||
try:
|
||||
self.bitbucket_client.update_pull_request(self.workspace_slug, self.repo_slug, str(self.pr_num), payload)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to update pull request, error: {e}")
|
||||
raise e
|
||||
|
||||
# bitbucket does not support labels
|
||||
def publish_labels(self, pr_types: list):
|
||||
pass
|
||||
|
||||
# bitbucket does not support labels
|
||||
def get_pr_labels(self, update=False):
|
||||
pass
|
||||
|
||||
def _get_pr_comments_path(self):
|
||||
return f"rest/api/latest/projects/{self.workspace_slug}/repos/{self.repo_slug}/pull-requests/{self.pr_num}/comments"
|
||||
|
||||
def _get_merge_base(self):
|
||||
return f"rest/api/latest/projects/{self.workspace_slug}/repos/{self.repo_slug}/pull-requests/{self.pr_num}/merge-base"
|
||||
277
apps/utils/pr_agent/git_providers/codecommit_client.py
Normal file
@ -0,0 +1,277 @@
|
||||
import boto3
|
||||
import botocore
|
||||
|
||||
|
||||
class CodeCommitDifferencesResponse:
|
||||
"""
|
||||
CodeCommitDifferencesResponse is the response object returned from our get_differences() function.
|
||||
It maps the JSON response to member variables of this class.
|
||||
"""
|
||||
|
||||
def __init__(self, json: dict):
|
||||
before_blob = json.get("beforeBlob", {})
|
||||
after_blob = json.get("afterBlob", {})
|
||||
|
||||
self.before_blob_id = before_blob.get("blobId", "")
|
||||
self.before_blob_path = before_blob.get("path", "")
|
||||
self.after_blob_id = after_blob.get("blobId", "")
|
||||
self.after_blob_path = after_blob.get("path", "")
|
||||
self.change_type = json.get("changeType", "")
|
||||
|
||||
|
||||
class CodeCommitPullRequestResponse:
|
||||
"""
|
||||
CodeCommitPullRequestResponse is the response object returned from our get_pr() function.
|
||||
It maps the JSON response to member variables of this class.
|
||||
"""
|
||||
|
||||
def __init__(self, json: dict):
|
||||
self.title = json.get("title", "")
|
||||
self.description = json.get("description", "")
|
||||
|
||||
self.targets = []
|
||||
for target in json.get("pullRequestTargets", []):
|
||||
self.targets.append(CodeCommitPullRequestResponse.CodeCommitPullRequestTarget(target))
|
||||
|
||||
class CodeCommitPullRequestTarget:
|
||||
"""
|
||||
CodeCommitPullRequestTarget is a subclass of CodeCommitPullRequestResponse that
|
||||
holds details about an individual target commit.
|
||||
"""
|
||||
|
||||
def __init__(self, json: dict):
|
||||
self.source_commit = json.get("sourceCommit", "")
|
||||
self.source_branch = json.get("sourceReference", "")
|
||||
self.destination_commit = json.get("destinationCommit", "")
|
||||
self.destination_branch = json.get("destinationReference", "")
|
||||
|
||||
|
||||
class CodeCommitClient:
|
||||
"""
|
||||
CodeCommitClient is a wrapper around the AWS boto3 SDK for the CodeCommit client
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.boto_client = None
|
||||
|
||||
def is_supported(self, capability: str) -> bool:
|
||||
if capability in ["gfm_markdown"]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _connect_boto_client(self):
|
||||
try:
|
||||
self.boto_client = boto3.client("codecommit")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to connect to AWS CodeCommit: {e}") from e
|
||||
|
||||
def get_differences(self, repo_name: str, destination_commit: str, source_commit: str):
|
||||
"""
|
||||
Get the differences between two commits in CodeCommit.
|
||||
|
||||
Args:
|
||||
- repo_name: Name of the repository
|
||||
- destination_commit: Commit hash you want to merge into (the "before" hash) (usually on the main or master branch)
|
||||
- source_commit: Commit hash of the code you are adding (the "after" branch)
|
||||
|
||||
Returns:
|
||||
- List of CodeCommitDifferencesResponse objects
|
||||
|
||||
Boto3 Documentation:
|
||||
- aws codecommit get-differences
|
||||
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/get_differences.html
|
||||
"""
|
||||
if self.boto_client is None:
|
||||
self._connect_boto_client()
|
||||
|
||||
# The differences response from AWS is paginated, so we need to iterate through the pages to get all the differences.
|
||||
differences = []
|
||||
try:
|
||||
paginator = self.boto_client.get_paginator("get_differences")
|
||||
for page in paginator.paginate(
|
||||
repositoryName=repo_name,
|
||||
beforeCommitSpecifier=destination_commit,
|
||||
afterCommitSpecifier=source_commit,
|
||||
):
|
||||
differences.extend(page.get("differences", []))
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
|
||||
raise ValueError(f"CodeCommit cannot retrieve differences: Repository does not exist: {repo_name}") from e
|
||||
raise ValueError(f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}") from e
|
||||
except Exception as e:
|
||||
raise ValueError(f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}") from e
|
||||
|
||||
output = []
|
||||
for json in differences:
|
||||
output.append(CodeCommitDifferencesResponse(json))
|
||||
return output
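# A hedged usage sketch (hypothetical repository name and commit ids; assumes AWS
# credentials for boto3 are available in the environment):
#   client = CodeCommitClient()
#   for d in client.get_differences("my-repo", "destCommitSha", "srcCommitSha"):
#       print(d.change_type, d.before_blob_path, "->", d.after_blob_path)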
|
||||
|
||||
def get_file(self, repo_name: str, file_path: str, sha_hash: str, optional: bool = False):
|
||||
"""
|
||||
Retrieve a file from CodeCommit.
|
||||
|
||||
Args:
|
||||
- repo_name: Name of the repository
|
||||
- file_path: Path to the file you are retrieving
|
||||
- sha_hash: Commit hash of the file you are retrieving
|
||||
|
||||
Returns:
|
||||
- File contents
|
||||
|
||||
Boto3 Documentation:
|
||||
- aws codecommit get_file
|
||||
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/get_file.html
|
||||
"""
|
||||
if not file_path:
|
||||
return ""
|
||||
|
||||
if self.boto_client is None:
|
||||
self._connect_boto_client()
|
||||
|
||||
try:
|
||||
response = self.boto_client.get_file(repositoryName=repo_name, commitSpecifier=sha_hash, filePath=file_path)
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
|
||||
raise ValueError(f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}") from e
|
||||
# if the file does not exist, but is flagged as optional, then return an empty string
|
||||
if optional and e.response["Error"]["Code"] == 'FileDoesNotExistException':
|
||||
return ""
|
||||
raise ValueError(f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'") from e
|
||||
except Exception as e:
|
||||
raise ValueError(f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'") from e
|
||||
if "fileContent" not in response:
|
||||
raise ValueError(f"File content is empty for file: {file_path}")
|
||||
|
||||
return response.get("fileContent", "")
|
||||
|
||||
def get_pr(self, repo_name: str, pr_number: int):
|
||||
"""
|
||||
Get information about a CodeCommit PR.
|
||||
|
||||
Args:
|
||||
- repo_name: Name of the repository
|
||||
- pr_number: The PR number you are requesting
|
||||
|
||||
Returns:
|
||||
- CodeCommitPullRequestResponse object
|
||||
|
||||
Boto3 Documentation:
|
||||
- aws codecommit get_pull_request
|
||||
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/get_pull_request.html
|
||||
"""
|
||||
if self.boto_client is None:
|
||||
self._connect_boto_client()
|
||||
|
||||
try:
|
||||
response = self.boto_client.get_pull_request(pullRequestId=str(pr_number))
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
|
||||
raise ValueError(f"CodeCommit cannot retrieve PR: PR number does not exist: {pr_number}") from e
|
||||
if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
|
||||
raise ValueError(f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}") from e
|
||||
raise ValueError(f"CodeCommit cannot retrieve PR: {pr_number}: boto client error") from e
|
||||
except Exception as e:
|
||||
raise ValueError(f"CodeCommit cannot retrieve PR: {pr_number}") from e
|
||||
|
||||
if "pullRequest" not in response:
|
||||
raise ValueError("CodeCommit PR number not found: {pr_number}")
|
||||
|
||||
return CodeCommitPullRequestResponse(response.get("pullRequest", {}))
|
||||
|
||||
def publish_description(self, pr_number: int, pr_title: str, pr_body: str):
|
||||
"""
|
||||
Set the title and description on a pull request
|
||||
|
||||
Args:
|
||||
- pr_number: the AWS CodeCommit pull request number
|
||||
- pr_title: title of the pull request
|
||||
- pr_body: body of the pull request
|
||||
|
||||
Returns:
|
||||
- None
|
||||
|
||||
Boto3 Documentation:
|
||||
- aws codecommit update_pull_request_title
|
||||
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/update_pull_request_title.html
|
||||
- aws codecommit update_pull_request_description
|
||||
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/update_pull_request_description.html
|
||||
"""
|
||||
if self.boto_client is None:
|
||||
self._connect_boto_client()
|
||||
|
||||
try:
|
||||
self.boto_client.update_pull_request_title(pullRequestId=str(pr_number), title=pr_title)
|
||||
self.boto_client.update_pull_request_description(pullRequestId=str(pr_number), description=pr_body)
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
|
||||
raise ValueError(f"PR number does not exist: {pr_number}") from e
|
||||
if e.response["Error"]["Code"] == 'InvalidTitleException':
|
||||
raise ValueError(f"Invalid title for PR number: {pr_number}") from e
|
||||
if e.response["Error"]["Code"] == 'InvalidDescriptionException':
|
||||
raise ValueError(f"Invalid description for PR number: {pr_number}") from e
|
||||
if e.response["Error"]["Code"] == 'PullRequestAlreadyClosedException':
|
||||
raise ValueError(f"PR is already closed: PR number: {pr_number}") from e
|
||||
raise ValueError(f"Boto3 client error calling publish_description") from e
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error calling publish_description") from e
|
||||
|
||||
def publish_comment(self, repo_name: str, pr_number: int, destination_commit: str, source_commit: str, comment: str, annotation_file: str = None, annotation_line: int = None):
|
||||
"""
|
||||
Publish a comment to a pull request
|
||||
|
||||
Args:
|
||||
- repo_name: name of the repository
|
||||
- pr_number: number of the pull request
|
||||
- destination_commit: The commit hash you want to merge into (the "before" hash) (usually on the main or master branch)
|
||||
- source_commit: The commit hash of the code you are adding (the "after" branch)
|
||||
- comment: The comment you want to publish
|
||||
- annotation_file: The file you want to annotate (optional)
|
||||
- annotation_line: The line number you want to annotate (optional)
|
||||
|
||||
Comment annotations for CodeCommit are different than GitHub.
|
||||
CodeCommit only designates the starting line number for the comment.
|
||||
It does not support the ending line number to highlight a range of lines.
|
||||
|
||||
Returns:
|
||||
- None
|
||||
|
||||
Boto3 Documentation:
|
||||
- aws codecommit post_comment_for_pull_request
|
||||
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/post_comment_for_pull_request.html
|
||||
"""
|
||||
if self.boto_client is None:
|
||||
self._connect_boto_client()
|
||||
|
||||
try:
|
||||
# If the comment has code annotations,
|
||||
# then set the file path and line number in the location dictionary
|
||||
if annotation_file and annotation_line:
|
||||
self.boto_client.post_comment_for_pull_request(
|
||||
pullRequestId=str(pr_number),
|
||||
repositoryName=repo_name,
|
||||
beforeCommitId=destination_commit,
|
||||
afterCommitId=source_commit,
|
||||
content=comment,
|
||||
location={
|
||||
"filePath": annotation_file,
|
||||
"filePosition": annotation_line,
|
||||
"relativeFileVersion": "AFTER",
|
||||
},
|
||||
)
|
||||
else:
|
||||
# The comment does not have code annotations
|
||||
self.boto_client.post_comment_for_pull_request(
|
||||
pullRequestId=str(pr_number),
|
||||
repositoryName=repo_name,
|
||||
beforeCommitId=destination_commit,
|
||||
afterCommitId=source_commit,
|
||||
content=comment,
|
||||
)
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
|
||||
raise ValueError(f"Repository does not exist: {repo_name}") from e
|
||||
if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
|
||||
raise ValueError(f"PR number does not exist: {pr_number}") from e
|
||||
raise ValueError(f"Boto3 client error calling post_comment_for_pull_request") from e
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error calling post_comment_for_pull_request") from e
|
||||
497
apps/utils/pr_agent/git_providers/codecommit_provider.py
Normal file
@ -0,0 +1,497 @@
|
||||
import os
|
||||
import re
|
||||
from collections import Counter
|
||||
from typing import List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from utils.pr_agent.algo.language_handler import is_valid_file
|
||||
from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
from utils.pr_agent.git_providers.codecommit_client import CodeCommitClient
|
||||
|
||||
from ..algo.utils import load_large_diff
|
||||
from ..config_loader import get_settings
|
||||
from ..log import get_logger
|
||||
from .git_provider import GitProvider
|
||||
|
||||
|
||||
class PullRequestCCMimic:
|
||||
"""
|
||||
This class mimics the PullRequest class from the PyGithub library for the CodeCommitProvider.
|
||||
"""
|
||||
|
||||
def __init__(self, title: str, diff_files: List[FilePatchInfo]):
|
||||
self.title = title
|
||||
self.diff_files = diff_files
|
||||
self.description = None
|
||||
self.source_commit = None
|
||||
self.source_branch = None # the branch containing your new code changes
|
||||
self.destination_commit = None
|
||||
self.destination_branch = None # the branch you are going to merge into
|
||||
|
||||
|
||||
class CodeCommitFile:
|
||||
"""
|
||||
This class represents a file in a pull request in CodeCommit.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
a_path: str,
|
||||
a_blob_id: str,
|
||||
b_path: str,
|
||||
b_blob_id: str,
|
||||
edit_type: EDIT_TYPE,
|
||||
):
|
||||
self.a_path = a_path
|
||||
self.a_blob_id = a_blob_id
|
||||
self.b_path = b_path
|
||||
self.b_blob_id = b_blob_id
|
||||
self.edit_type: EDIT_TYPE = edit_type
|
||||
self.filename = b_path if b_path else a_path
|
||||
|
||||
|
||||
class CodeCommitProvider(GitProvider):
|
||||
"""
|
||||
This class implements the GitProvider interface for AWS CodeCommit repositories.
|
||||
"""
|
||||
|
||||
def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
|
||||
self.codecommit_client = CodeCommitClient()
|
||||
self.aws_client = None
|
||||
self.repo_name = None
|
||||
self.pr_num = None
|
||||
self.pr = None
|
||||
self.diff_files = None
|
||||
self.git_files = None
|
||||
self.pr_url = pr_url
|
||||
if pr_url:
|
||||
self.set_pr(pr_url)
|
||||
|
||||
def provider_name(self):
|
||||
return "CodeCommit"
|
||||
|
||||
def is_supported(self, capability: str) -> bool:
|
||||
if capability in [
|
||||
"get_issue_comments",
|
||||
"create_inline_comment",
|
||||
"publish_inline_comments",
|
||||
"get_labels",
|
||||
"gfm_markdown"
|
||||
]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def set_pr(self, pr_url: str):
|
||||
self.repo_name, self.pr_num = self._parse_pr_url(pr_url)
|
||||
self.pr = self._get_pr()
|
||||
|
||||
def get_files(self) -> list[CodeCommitFile]:
|
||||
# bring files from CodeCommit only once
|
||||
if self.git_files:
|
||||
return self.git_files
|
||||
|
||||
self.git_files = []
|
||||
differences = self.codecommit_client.get_differences(self.repo_name, self.pr.destination_commit, self.pr.source_commit)
|
||||
for item in differences:
|
||||
self.git_files.append(CodeCommitFile(item.before_blob_path,
|
||||
item.before_blob_id,
|
||||
item.after_blob_path,
|
||||
item.after_blob_id,
|
||||
CodeCommitProvider._get_edit_type(item.change_type)))
|
||||
return self.git_files
|
||||
|
||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||
"""
|
||||
Retrieves the list of files that have been modified, added, deleted, or renamed in a pull request in CodeCommit,
|
||||
along with their content and patch information.
|
||||
|
||||
Returns:
|
||||
diff_files (List[FilePatchInfo]): List of FilePatchInfo objects representing the modified, added, deleted,
|
||||
or renamed files in the merge request.
|
||||
"""
|
||||
# bring files from CodeCommit only once
|
||||
if self.diff_files:
|
||||
return self.diff_files
|
||||
|
||||
self.diff_files = []
|
||||
|
||||
files = self.get_files()
|
||||
for diff_item in files:
|
||||
patch_filename = ""
|
||||
if diff_item.a_blob_id is not None:
|
||||
patch_filename = diff_item.a_path
|
||||
original_file_content_str = self.codecommit_client.get_file(
|
||||
self.repo_name, diff_item.a_path, self.pr.destination_commit)
|
||||
if isinstance(original_file_content_str, (bytes, bytearray)):
|
||||
original_file_content_str = original_file_content_str.decode("utf-8")
|
||||
else:
|
||||
original_file_content_str = ""
|
||||
|
||||
if diff_item.b_blob_id is not None:
|
||||
patch_filename = diff_item.b_path
|
||||
new_file_content_str = self.codecommit_client.get_file(self.repo_name, diff_item.b_path, self.pr.source_commit)
|
||||
if isinstance(new_file_content_str, (bytes, bytearray)):
|
||||
new_file_content_str = new_file_content_str.decode("utf-8")
|
||||
else:
|
||||
new_file_content_str = ""
|
||||
|
||||
patch = load_large_diff(patch_filename, new_file_content_str, original_file_content_str)
|
||||
|
||||
# Store the diffs as a list of FilePatchInfo objects
|
||||
info = FilePatchInfo(
|
||||
original_file_content_str,
|
||||
new_file_content_str,
|
||||
patch,
|
||||
diff_item.b_path,
|
||||
edit_type=diff_item.edit_type,
|
||||
old_filename=None
|
||||
if diff_item.a_path == diff_item.b_path
|
||||
else diff_item.a_path,
|
||||
)
|
||||
# Only add valid files to the diff list
|
||||
# "bad extensions" are set in the language_extensions.toml file
|
||||
# a "valid file" is one that is not in the "bad extensions" list
|
||||
if is_valid_file(info.filename):
|
||||
self.diff_files.append(info)
|
||||
|
||||
return self.diff_files
|
||||
|
||||
def publish_description(self, pr_title: str, pr_body: str):
|
||||
try:
|
||||
self.codecommit_client.publish_description(
|
||||
pr_number=self.pr_num,
|
||||
pr_title=pr_title,
|
||||
pr_body=CodeCommitProvider._add_additional_newlines(pr_body),
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"CodeCommit Cannot publish description for PR: {self.pr_num}") from e
|
||||
|
||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||
if is_temporary:
|
||||
get_logger().info(pr_comment)
|
||||
return
|
||||
|
||||
pr_comment = CodeCommitProvider._remove_markdown_html(pr_comment)
|
||||
pr_comment = CodeCommitProvider._add_additional_newlines(pr_comment)
|
||||
|
||||
try:
|
||||
self.codecommit_client.publish_comment(
|
||||
repo_name=self.repo_name,
|
||||
pr_number=self.pr_num,
|
||||
destination_commit=self.pr.destination_commit,
|
||||
source_commit=self.pr.source_commit,
|
||||
comment=pr_comment,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"CodeCommit Cannot publish comment for PR: {self.pr_num}") from e
|
||||
|
||||
def publish_code_suggestions(self, code_suggestions: list) -> bool:
|
||||
counter = 1
|
||||
for suggestion in code_suggestions:
|
||||
# Verify that each suggestion has the required keys
|
||||
if not all(key in suggestion for key in ["body", "relevant_file", "relevant_lines_start"]):
|
||||
get_logger().warning(f"Skipping code suggestion #{counter}: Each suggestion must have 'body', 'relevant_file', 'relevant_lines_start' keys")
|
||||
continue
|
||||
|
||||
# Publish the code suggestion to CodeCommit
|
||||
try:
|
||||
get_logger().debug(f"Code Suggestion #{counter} in file: {suggestion['relevant_file']}: {suggestion['relevant_lines_start']}")
|
||||
self.codecommit_client.publish_comment(
|
||||
repo_name=self.repo_name,
|
||||
pr_number=self.pr_num,
|
||||
destination_commit=self.pr.destination_commit,
|
||||
source_commit=self.pr.source_commit,
|
||||
comment=suggestion["body"],
|
||||
annotation_file=suggestion["relevant_file"],
|
||||
annotation_line=suggestion["relevant_lines_start"],
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"CodeCommit Cannot publish code suggestions for PR: {self.pr_num}") from e
|
||||
|
||||
counter += 1
|
||||
|
||||
# The calling function passes in a list of code suggestions, and this function publishes each suggestion one at a time.
|
||||
# If we were to return False here, the calling function will attempt to publish the same list of code suggestions again, one at a time.
|
||||
# Since this function publishes the suggestions one at a time anyway, we always return True here to avoid the retry.
|
||||
return True
|
||||
|
||||
def publish_labels(self, labels):
|
||||
return [""] # not implemented yet
|
||||
|
||||
def get_pr_labels(self, update=False):
|
||||
return [""] # not implemented yet
|
||||
|
||||
def remove_initial_comment(self):
|
||||
return "" # not implemented yet
|
||||
|
||||
def remove_comment(self, comment):
|
||||
return "" # not implemented yet
|
||||
|
||||
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/post_comment_for_compared_commit.html
|
||||
raise NotImplementedError("CodeCommit provider does not support publishing inline comments yet")
|
||||
|
||||
def publish_inline_comments(self, comments: list[dict]):
|
||||
raise NotImplementedError("CodeCommit provider does not support publishing inline comments yet")
|
||||
|
||||
def get_title(self):
|
||||
return self.pr.title
|
||||
|
||||
def get_pr_id(self):
|
||||
"""
|
||||
Returns the PR ID in the format: "repo_name/pr_number".
|
||||
Note: This is an internal identifier for PR-Agent,
|
||||
and is not the same as the CodeCommit PR identifier.
|
||||
"""
|
||||
try:
|
||||
pr_id = f"{self.repo_name}/{self.pr_num}"
|
||||
return pr_id
|
||||
except:
|
||||
return ""
|
||||
|
||||
def get_languages(self):
|
||||
"""
|
||||
Returns a dictionary of languages, containing the percentage of each language used in the PR.
|
||||
|
||||
Returns:
|
||||
- dict: A dictionary where each key is a language name and the corresponding value is the percentage of that language in the PR.
|
||||
"""
|
||||
commit_files = self.get_files()
|
||||
filenames = [ item.filename for item in commit_files ]
|
||||
extensions = CodeCommitProvider._get_file_extensions(filenames)
|
||||
|
||||
# Calculate the percentage of each file extension in the PR
|
||||
percentages = CodeCommitProvider._get_language_percentages(extensions)
|
||||
|
||||
# The global language_extension_map is a dictionary of languages,
|
||||
# where each dictionary item is a BoxList of extensions.
|
||||
# We want a dictionary of extensions,
|
||||
# where each dictionary item is a language name.
|
||||
# We build that language->extension dictionary here in main_extensions_flat.
|
||||
main_extensions_flat = {}
|
||||
language_extension_map_org = get_settings().language_extension_map_org
|
||||
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
|
||||
for language, extensions in language_extension_map.items():
|
||||
for ext in extensions:
|
||||
main_extensions_flat[ext] = language
|
||||
|
||||
# Map the file extension/languages to percentages
|
||||
languages = {}
|
||||
for ext, pct in percentages.items():
|
||||
languages[main_extensions_flat.get(ext, "")] = pct
|
||||
|
||||
return languages
|
||||
|
||||
def get_pr_branch(self):
|
||||
return self.pr.source_branch
|
||||
|
||||
def get_pr_description_full(self) -> str:
|
||||
return self.pr.description
|
||||
|
||||
def get_user_id(self):
|
||||
return -1 # not implemented yet
|
||||
|
||||
def get_issue_comments(self):
|
||||
raise NotImplementedError("CodeCommit provider does not support issue comments yet")
|
||||
|
||||
def get_repo_settings(self):
|
||||
# a local ".pr_agent.toml" settings file is optional
|
||||
settings_filename = ".pr_agent.toml"
|
||||
return self.codecommit_client.get_file(self.repo_name, settings_filename, self.pr.source_commit, optional=True)
|
||||
|
||||
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
|
||||
get_logger().info("CodeCommit provider does not support eyes reaction yet")
|
||||
return True
|
||||
|
||||
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
|
||||
get_logger().info("CodeCommit provider does not support removing reactions yet")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _parse_pr_url(pr_url: str) -> Tuple[str, int]:
|
||||
"""
|
||||
Parse the CodeCommit PR URL and return the repository name and PR number.
|
||||
|
||||
Args:
|
||||
- pr_url: the full AWS CodeCommit pull request URL
|
||||
|
||||
Returns:
|
||||
- Tuple[str, int]: A tuple containing the repository name and PR number.
|
||||
"""
|
||||
# Example PR URL:
|
||||
# https://us-east-1.console.aws.amazon.com/codesuite/codecommit/repositories/__MY_REPO__/pull-requests/123456"
|
||||
parsed_url = urlparse(pr_url)
|
||||
|
||||
if not CodeCommitProvider._is_valid_codecommit_hostname(parsed_url.netloc):
|
||||
raise ValueError(f"The provided URL is not a valid CodeCommit URL: {pr_url}")
|
||||
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
|
||||
if (
|
||||
len(path_parts) < 6
|
||||
or path_parts[0] != "codesuite"
|
||||
or path_parts[1] != "codecommit"
|
||||
or path_parts[2] != "repositories"
|
||||
or path_parts[4] != "pull-requests"
|
||||
):
|
||||
raise ValueError(f"The provided URL does not appear to be a CodeCommit PR URL: {pr_url}")
|
||||
|
||||
repo_name = path_parts[3]
|
||||
|
||||
try:
|
||||
pr_number = int(path_parts[5])
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Unable to convert PR number to integer: '{path_parts[5]}'") from e
|
||||
|
||||
return repo_name, pr_number
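# Hypothetical example for the console URL format documented above:
#   _parse_pr_url("https://us-east-1.console.aws.amazon.com/codesuite/codecommit/repositories/my-repo/pull-requests/123456")
#   -> ("my-repo", 123456)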
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_codecommit_hostname(hostname: str) -> bool:
|
||||
"""
|
||||
Check if the provided hostname is a valid AWS CodeCommit hostname.
|
||||
|
||||
This is not an exhaustive check of AWS region names,
|
||||
but instead uses a regex to check for matching AWS region patterns.
|
||||
|
||||
Args:
|
||||
- hostname: the hostname to check
|
||||
|
||||
Returns:
|
||||
- bool: True if the hostname is valid, False otherwise.
|
||||
"""
|
||||
return re.match(r"^[a-z]{2}-(gov-)?[a-z]+-\d\.console\.aws\.amazon\.com$", hostname) is not None
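# Illustrative hostnames against the pattern above (hypothetical values):
#   "us-east-1.console.aws.amazon.com"      -> valid
#   "us-gov-west-1.console.aws.amazon.com"  -> valid
#   "codecommit.example.com"                -> invalid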
|
||||
|
||||
def _get_pr(self):
|
||||
response = self.codecommit_client.get_pr(self.repo_name, self.pr_num)
|
||||
|
||||
if len(response.targets) == 0:
|
||||
raise ValueError(f"No files found in CodeCommit PR: {self.pr_num}")
|
||||
|
||||
# TODO: implement support for multiple targets in one CodeCommit PR
|
||||
# for now, we are only using the first target in the PR
|
||||
if len(response.targets) > 1:
|
||||
get_logger().warning(
|
||||
"Multiple targets in one PR is not supported for CodeCommit yet. Continuing, using the first target only..."
|
||||
)
|
||||
|
||||
# Return our object that mimics PullRequest class from the PyGithub library
|
||||
# (This strategy was copied from the LocalGitProvider)
|
||||
mimic = PullRequestCCMimic(response.title, self.diff_files)
|
||||
mimic.description = response.description
|
||||
mimic.source_commit = response.targets[0].source_commit
|
||||
mimic.source_branch = response.targets[0].source_branch
|
||||
mimic.destination_commit = response.targets[0].destination_commit
|
||||
mimic.destination_branch = response.targets[0].destination_branch
|
||||
|
||||
return mimic
|
||||
|
||||
def get_commit_messages(self):
|
||||
return "" # not implemented yet
|
||||
|
||||
@staticmethod
|
||||
def _add_additional_newlines(body: str) -> str:
|
||||
"""
|
||||
Replace single newlines in a PR body with double newlines.
|
||||
|
||||
CodeCommit Markdown does not seem to render as well as GitHub Markdown,
|
||||
so we add additional newlines to the PR body to make it more readable in CodeCommit.
|
||||
|
||||
Args:
|
||||
- body: the PR body
|
||||
|
||||
Returns:
|
||||
- str: the PR body with the double newlines added
|
||||
"""
|
||||
return re.sub(r'(?<!\n)\n(?!\n)', '\n\n', body)
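# Illustrative effect of the substitution above: lone newlines are doubled while
# existing blank lines are preserved, e.g.
#   "line one\nline two\n\nline three" -> "line one\n\nline two\n\nline three"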
|
||||
|
||||
@staticmethod
|
||||
def _remove_markdown_html(comment: str) -> str:
|
||||
"""
|
||||
Remove the HTML tags from a PR comment.
|
||||
|
||||
CodeCommit Markdown does not seem to render as well as GitHub Markdown,
|
||||
so we remove the HTML tags from the PR comment to make it more readable in CodeCommit.
|
||||
|
||||
Args:
|
||||
- comment: the PR comment
|
||||
|
||||
Returns:
|
||||
- str: the PR comment with the HTML tags removed
|
||||
"""
|
||||
comment = comment.replace("<details>", "")
|
||||
comment = comment.replace("</details>", "")
|
||||
comment = comment.replace("<summary>", "")
|
||||
comment = comment.replace("</summary>", "")
|
||||
return comment
|
||||
|
||||
@staticmethod
|
||||
def _get_edit_type(codecommit_change_type: str):
|
||||
"""
|
||||
Convert the CodeCommit change type string to the EDIT_TYPE enum.
|
||||
The CodeCommit change type string is returned from the get_differences SDK method.
|
||||
|
||||
Args:
|
||||
- codecommit_change_type: the CodeCommit change type string
|
||||
|
||||
Returns:
|
||||
- An EDIT_TYPE enum representing the modified, added, deleted, or renamed file in the PR diff.
|
||||
"""
|
||||
t = codecommit_change_type.upper()
|
||||
edit_type = None
|
||||
if t == "A":
|
||||
edit_type = EDIT_TYPE.ADDED
|
||||
elif t == "D":
|
||||
edit_type = EDIT_TYPE.DELETED
|
||||
elif t == "M":
|
||||
edit_type = EDIT_TYPE.MODIFIED
|
||||
elif t == "R":
|
||||
edit_type = EDIT_TYPE.RENAMED
|
||||
return edit_type
|
||||
|
||||
@staticmethod
|
||||
def _get_file_extensions(filenames):
|
||||
"""
|
||||
Return a list of file extensions from a list of filenames.
|
||||
The returned extensions will include the dot "." prefix,
|
||||
to accommodate for the dots in the existing language_extension_map settings.
|
||||
Filenames with no extension will return an empty string for the extension.
|
||||
|
||||
Args:
|
||||
- filenames: a list of filenames
|
||||
|
||||
Returns:
|
||||
- list: A list of file extensions, including the dot "." prefix.
|
||||
"""
|
||||
extensions = []
|
||||
for filename in filenames:
|
||||
filename, ext = os.path.splitext(filename)
|
||||
if ext:
|
||||
extensions.append(ext.lower())
|
||||
else:
|
||||
extensions.append("")
|
||||
return extensions
|
||||
|
||||
@staticmethod
|
||||
def _get_language_percentages(extensions):
|
||||
"""
|
||||
Return a dictionary containing the programming language name (as the key),
|
||||
and the percentage that language is used (as the value),
|
||||
given a list of file extensions.
|
||||
|
||||
Args:
|
||||
- extensions: a list of file extensions
|
||||
|
||||
Returns:
|
||||
- dict: A dictionary where each key is a language name and the corresponding value is the percentage of that language in the PR.
|
||||
"""
|
||||
total_files = len(extensions)
|
||||
if total_files == 0:
|
||||
return {}
|
||||
|
||||
# Identify language by file extension and count
|
||||
lang_count = Counter(extensions)
|
||||
# Convert counts to percentages
|
||||
lang_percentage = {
|
||||
lang: round(count / total_files * 100) for lang, count in lang_count.items()
|
||||
}
|
||||
return lang_percentage
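# Hypothetical input/output: [".py", ".py", ".md", ""] -> {".py": 50, ".md": 25, "": 25}
# (values are rounded whole percentages, so they may not always sum to exactly 100).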
|
||||
399
apps/utils/pr_agent/git_providers/gerrit_provider.py
Normal file
@ -0,0 +1,399 @@
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import shutil
|
||||
import subprocess
|
||||
import uuid
|
||||
from collections import Counter, namedtuple
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile, mkdtemp
|
||||
|
||||
import requests
|
||||
import urllib3.util
|
||||
from git import Repo
|
||||
|
||||
from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
from utils.pr_agent.git_providers.git_provider import GitProvider
|
||||
from utils.pr_agent.git_providers.local_git_provider import PullRequestMimic
|
||||
from utils.pr_agent.log import get_logger
|
||||
|
||||
|
||||
def _call(*command, **kwargs) -> str:
|
||||
res = subprocess.run(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
check=True,
|
||||
**kwargs,
|
||||
)
|
||||
return res.stdout.decode()
|
||||
|
||||
|
||||
def clone(url, directory):
|
||||
get_logger().info("Cloning %s to %s", url, directory)
|
||||
stdout = _call('git', 'clone', "--depth", "1", url, directory)
|
||||
get_logger().info(stdout)
|
||||
|
||||
|
||||
def fetch(url, refspec, cwd):
|
||||
get_logger().info("Fetching %s %s", url, refspec)
|
||||
stdout = _call(
|
||||
'git', 'fetch', '--depth', '2', url, refspec,
|
||||
cwd=cwd
|
||||
)
|
||||
get_logger().info(stdout)
|
||||
|
||||
|
||||
def checkout(cwd):
|
||||
get_logger().info("Checking out")
|
||||
stdout = _call('git', 'checkout', "FETCH_HEAD", cwd=cwd)
|
||||
get_logger().info(stdout)
|
||||
|
||||
|
||||
def show(*args, cwd=None):
|
||||
get_logger().info("Show")
|
||||
return _call('git', 'show', *args, cwd=cwd)
|
||||
|
||||
|
||||
def diff(*args, cwd=None):
|
||||
get_logger().info("Diff")
|
||||
patch = _call('git', 'diff', *args, cwd=cwd)
|
||||
if not patch:
|
||||
get_logger().warning("No changes found")
|
||||
return
|
||||
return patch
|
||||
|
||||
|
||||
def reset_local_changes(cwd):
|
||||
get_logger().info("Reset local changes")
|
||||
_call('git', 'checkout', "--force", cwd=cwd)
|
||||
|
||||
|
||||
def add_comment(url: urllib3.util.Url, refspec, message):
|
||||
*_, patchset, changenum = refspec.rsplit("/")
|
||||
message = "'" + message.replace("'", "'\"'\"'") + "'"
|
||||
return _call(
|
||||
"ssh",
|
||||
"-p", str(url.port),
|
||||
f"{url.auth}@{url.host}",
|
||||
"gerrit", "review",
|
||||
"--message", message,
|
||||
# "--code-review", score,
|
||||
f"{patchset},{changenum}",
|
||||
)
|
||||
|
||||
|
||||
def list_comments(url: urllib3.util.Url, refspec):
|
||||
*_, patchset, _ = refspec.rsplit("/")
|
||||
stdout = _call(
|
||||
"ssh",
|
||||
"-p", str(url.port),
|
||||
f"{url.auth}@{url.host}",
|
||||
"gerrit", "query",
|
||||
"--comments",
|
||||
"--current-patch-set", patchset,
|
||||
"--format", "JSON",
|
||||
)
|
||||
change_set, *_ = stdout.splitlines()
|
||||
return json.loads(change_set)["currentPatchSet"]["comments"]
|
||||
|
||||
|
||||
def prepare_repo(url: urllib3.util.Url, project, refspec):
|
||||
repo_url = (f"{url.scheme}://{url.auth}@{url.host}:{url.port}/{project}")
|
||||
|
||||
directory = pathlib.Path(mkdtemp())
|
||||
clone(repo_url, directory)
|
||||
fetch(repo_url, refspec, cwd=directory)
|
||||
checkout(cwd=directory)
|
||||
return directory
|
||||
|
||||
|
||||
def adopt_to_gerrit_message(message):
|
||||
lines = message.splitlines()
|
||||
buf = []
|
||||
for line in lines:
|
||||
# remove markdown formatting
|
||||
line = (line.replace("*", "")
|
||||
.replace("``", "`")
|
||||
.replace("<details>", "")
|
||||
.replace("</details>", "")
|
||||
.replace("<summary>", "")
|
||||
.replace("</summary>", ""))
|
||||
|
||||
line = line.strip()
|
||||
if line.startswith('#'):
|
||||
buf.append("\n" +
|
||||
line.replace('#', '').removesuffix(":").strip() +
|
||||
":")
|
||||
continue
|
||||
elif line.startswith('-'):
|
||||
buf.append(line.removeprefix('-').strip())
|
||||
continue
|
||||
else:
|
||||
buf.append(line)
|
||||
return "\n".join(buf).strip()
|
||||
|
||||
|
||||
def add_suggestion(src_filename, context: str, start, end: int):
|
||||
with (
|
||||
NamedTemporaryFile("w", delete=False) as tmp,
|
||||
open(src_filename, "r") as src
|
||||
):
|
||||
lines = src.readlines()
|
||||
tmp.writelines(lines[:start - 1])
|
||||
if context:
|
||||
tmp.write(context)
|
||||
tmp.writelines(lines[end:])
|
||||
|
||||
shutil.copy(tmp.name, src_filename)
|
||||
os.remove(tmp.name)
|
||||
|
||||
|
||||
def upload_patch(patch, path):
|
||||
patch_server_endpoint = get_settings().get(
|
||||
'gerrit.patch_server_endpoint')
|
||||
patch_server_token = get_settings().get(
|
||||
'gerrit.patch_server_token')
|
||||
|
||||
response = requests.post(
|
||||
patch_server_endpoint,
|
||||
json={
|
||||
"content": patch,
|
||||
"path": path,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {patch_server_token}",
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
patch_server_endpoint = patch_server_endpoint.rstrip("/")
|
||||
return patch_server_endpoint + "/" + path
|
||||
|
||||
|
||||
class GerritProvider(GitProvider):
|
||||
|
||||
def __init__(self, key: str, incremental=False):
|
||||
self.project, self.refspec = key.split(':')
|
||||
assert self.project, "Project name is required"
|
||||
assert self.refspec, "Refspec is required"
|
||||
base_url = get_settings().get('gerrit.url')
|
||||
assert base_url, "Gerrit URL is required"
|
||||
user = get_settings().get('gerrit.user')
|
||||
assert user, "Gerrit user is required"
|
||||
|
||||
parsed = urllib3.util.parse_url(base_url)
|
||||
self.parsed_url = urllib3.util.parse_url(
|
||||
f"{parsed.scheme}://{user}@{parsed.host}:{parsed.port}"
|
||||
)
|
||||
|
||||
self.repo_path = prepare_repo(
|
||||
self.parsed_url, self.project, self.refspec
|
||||
)
|
||||
self.repo = Repo(self.repo_path)
|
||||
assert self.repo
|
||||
self.pr_url = base_url
|
||||
self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())
|
||||
|
||||
def get_pr_title(self):
|
||||
"""
|
||||
Substitutes the branch-name as the PR-mimic title.
|
||||
"""
|
||||
return self.repo.branches[0].name
|
||||
|
||||
def get_issue_comments(self):
|
||||
comments = list_comments(self.parsed_url, self.refspec)
|
||||
Comments = namedtuple('Comments', ['reversed'])
|
||||
Comment = namedtuple('Comment', ['body'])
|
||||
return Comments([Comment(c['message']) for c in reversed(comments)])
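# A hedged note: the namedtuples above only mimic the minimal shape callers expect,
# i.e. get_issue_comments().reversed is a list of Comment items whose .body holds
# the Gerrit comment message text.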
|
||||
|
||||
def get_pr_labels(self, update=False):
|
||||
raise NotImplementedError(
|
||||
'Getting labels is not implemented for the gerrit provider')
|
||||
|
||||
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False):
|
||||
raise NotImplementedError(
|
||||
'Adding reactions is not implemented for the gerrit provider')
|
||||
|
||||
def remove_reaction(self, issue_comment_id: int, reaction_id: int):
|
||||
raise NotImplementedError(
|
||||
'Removing reactions is not implemented for the gerrit provider')
|
||||
|
||||
def get_commit_messages(self):
|
||||
return [self.repo.head.commit.message]
|
||||
|
||||
def get_repo_settings(self):
|
||||
try:
|
||||
with open(self.repo_path / ".pr_agent.toml", 'rb') as f:
|
||||
contents = f.read()
|
||||
return contents
|
||||
except OSError:
|
||||
return b""
|
||||
|
||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||
diffs = self.repo.head.commit.diff(
|
||||
self.repo.head.commit.parents[0], # previous commit
|
||||
create_patch=True,
|
||||
R=True
|
||||
)
|
||||
|
||||
diff_files = []
|
||||
for diff_item in diffs:
|
||||
if diff_item.a_blob is not None:
|
||||
original_file_content_str = (
|
||||
diff_item.a_blob.data_stream.read().decode('utf-8')
|
||||
)
|
||||
else:
|
||||
original_file_content_str = "" # empty file
|
||||
if diff_item.b_blob is not None:
|
||||
new_file_content_str = diff_item.b_blob.data_stream.read(). \
|
||||
decode('utf-8')
|
||||
else:
|
||||
new_file_content_str = "" # empty file
|
||||
edit_type = EDIT_TYPE.MODIFIED
|
||||
if diff_item.new_file:
|
||||
edit_type = EDIT_TYPE.ADDED
|
||||
elif diff_item.deleted_file:
|
||||
edit_type = EDIT_TYPE.DELETED
|
||||
elif diff_item.renamed_file:
|
||||
edit_type = EDIT_TYPE.RENAMED
|
||||
diff_files.append(
|
||||
FilePatchInfo(
|
||||
original_file_content_str,
|
||||
new_file_content_str,
|
||||
diff_item.diff.decode('utf-8'),
|
||||
diff_item.b_path,
|
||||
edit_type=edit_type,
|
||||
old_filename=None
|
||||
if diff_item.a_path == diff_item.b_path
|
||||
else diff_item.a_path
|
||||
)
|
||||
)
|
||||
self.diff_files = diff_files
|
||||
return diff_files
|
||||
|
||||
def get_files(self):
|
||||
diff_index = self.repo.head.commit.diff(
|
||||
self.repo.head.commit.parents[0], # previous commit
|
||||
R=True
|
||||
)
|
||||
# Get the list of changed files
|
||||
diff_files = [item.a_path for item in diff_index]
|
||||
return diff_files
|
||||
|
||||
def get_languages(self):
|
||||
"""
|
||||
Calculate percentage of languages in repository. Used for hunk
|
||||
prioritisation.
|
||||
"""
|
||||
# Get all files in repository
|
||||
filepaths = [Path(item.path) for item in
|
||||
self.repo.tree().traverse() if item.type == 'blob']
|
||||
# Identify language by file extension and count
|
||||
lang_count = Counter(
|
||||
ext.lstrip('.') for filepath in filepaths for ext in
|
||||
[filepath.suffix.lower()])
|
||||
# Convert counts to percentages
|
||||
total_files = len(filepaths)
|
||||
lang_percentage = {lang: count / total_files * 100 for lang, count
|
||||
in lang_count.items()}
|
||||
return lang_percentage
|
||||
|
||||
def get_pr_description_full(self):
|
||||
return self.repo.head.commit.message
|
||||
|
||||
def get_user_id(self):
|
||||
return self.repo.head.commit.author.email
|
||||
|
||||
def is_supported(self, capability: str) -> bool:
|
||||
if capability in [
|
||||
# 'get_issue_comments',
|
||||
'create_inline_comment',
|
||||
'publish_inline_comments',
|
||||
'get_labels',
|
||||
'gfm_markdown'
|
||||
]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def split_suggestion(self, msg) -> tuple[str, str]:
|
||||
is_code_context = False
|
||||
description = []
|
||||
context = []
|
||||
for line in msg.splitlines():
|
||||
if line.startswith('```suggestion'):
|
||||
is_code_context = True
|
||||
continue
|
||||
if line.startswith('```'):
|
||||
is_code_context = False
|
||||
continue
|
||||
if is_code_context:
|
||||
context.append(line)
|
||||
else:
|
||||
description.append(
|
||||
line.replace('*', '')
|
||||
)
|
||||
|
||||
return (
|
||||
'\n'.join(description),
|
||||
'\n'.join(context) + '\n' if context else ''
|
||||
)
|
||||
|
||||
def publish_code_suggestions(self, code_suggestions: list):
|
||||
msg = []
|
||||
for suggestion in code_suggestions:
|
||||
description, code = self.split_suggestion(suggestion['body'])
|
||||
add_suggestion(
|
||||
pathlib.Path(self.repo_path) / suggestion["relevant_file"],
|
||||
code,
|
||||
suggestion["relevant_lines_start"],
|
||||
suggestion["relevant_lines_end"],
|
||||
)
|
||||
patch = diff(cwd=self.repo_path)
|
||||
patch_id = uuid.uuid4().hex[0:4]
|
||||
path = "/".join(["codium-ai", self.refspec, patch_id])
|
||||
full_path = upload_patch(patch, path)
|
||||
reset_local_changes(self.repo_path)
|
||||
msg.append(f'* {description}\n{full_path}')
|
||||
|
||||
if msg:
|
||||
add_comment(self.parsed_url, self.refspec, "\n".join(msg))
|
||||
return True
|
||||
|
||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||
if not is_temporary:
|
||||
msg = adopt_to_gerrit_message(pr_comment)
|
||||
add_comment(self.parsed_url, self.refspec, msg)
|
||||
|
||||
def publish_description(self, pr_title: str, pr_body: str):
|
||||
msg = adopt_to_gerrit_message(pr_body)
|
||||
add_comment(self.parsed_url, self.refspec, pr_title + '\n' + msg)
|
||||
|
||||
def publish_inline_comments(self, comments: list[dict]):
|
||||
raise NotImplementedError(
|
||||
'Publishing inline comments is not implemented for the gerrit '
|
||||
'provider')
|
||||
|
||||
def publish_inline_comment(self, body: str, relevant_file: str,
|
||||
relevant_line_in_file: str, original_suggestion=None):
|
||||
raise NotImplementedError(
|
||||
'Publishing inline comments is not implemented for the gerrit '
|
||||
'provider')
|
||||
|
||||
|
||||
def publish_labels(self, labels):
|
||||
# Not applicable to the local git provider,
|
||||
# but required by the interface
|
||||
pass
|
||||
|
||||
def remove_initial_comment(self):
|
||||
# remove repo, cloned in previous steps
|
||||
# shutil.rmtree(self.repo_path)
|
||||
pass
|
||||
|
||||
def remove_comment(self, comment):
|
||||
pass
|
||||
|
||||
def get_pr_branch(self):
|
||||
return self.repo.head
|
||||
350
apps/utils/pr_agent/git_providers/git_provider.py
Normal file
@ -0,0 +1,350 @@
|
||||
from abc import ABC, abstractmethod
|
||||
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
|
||||
from typing import Optional
|
||||
|
||||
from utils.pr_agent.algo.types import FilePatchInfo
|
||||
from utils.pr_agent.algo.utils import Range, process_description
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
from utils.pr_agent.log import get_logger
|
||||
|
||||
MAX_FILES_ALLOWED_FULL = 50
|
||||
|
||||
class GitProvider(ABC):
|
||||
@abstractmethod
|
||||
def is_supported(self, capability: str) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_files(self) -> list:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||
pass
|
||||
|
||||
def get_incremental_commits(self, is_incremental):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def publish_description(self, pr_title: str, pr_body: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def publish_code_suggestions(self, code_suggestions: list) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_languages(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pr_branch(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_user_id(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pr_description_full(self) -> str:
|
||||
pass
|
||||
|
||||
def edit_comment(self, comment, body: str):
|
||||
pass
|
||||
|
||||
def edit_comment_from_comment_id(self, comment_id: int, body: str):
|
||||
pass
|
||||
|
||||
def get_comment_body_from_comment_id(self, comment_id: int) -> str:
|
||||
pass
|
||||
|
||||
def reply_to_comment_from_comment_id(self, comment_id: int, body: str):
|
||||
pass
|
||||
|
||||
def get_pr_description(self, full: bool = True, split_changes_walkthrough=False) -> str or tuple:
|
||||
from utils.pr_agent.algo.utils import clip_tokens
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
max_tokens_description = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
|
||||
description = self.get_pr_description_full() if full else self.get_user_description()
|
||||
if split_changes_walkthrough:
|
||||
description, files = process_description(description)
|
||||
if max_tokens_description:
|
||||
description = clip_tokens(description, max_tokens_description)
|
||||
return description, files
|
||||
else:
|
||||
if max_tokens_description:
|
||||
description = clip_tokens(description, max_tokens_description)
|
||||
return description
|
||||
|
||||
def get_user_description(self) -> str:
|
||||
if hasattr(self, 'user_description') and not (self.user_description is None):
|
||||
return self.user_description
|
||||
|
||||
description = (self.get_pr_description_full() or "").strip()
|
||||
description_lowercase = description.lower()
|
||||
get_logger().debug(f"Existing description", description=description_lowercase)
|
||||
|
||||
# if the existing description wasn't generated by the pr-agent, just return it as-is
|
||||
if not self._is_generated_by_pr_agent(description_lowercase):
|
||||
get_logger().info(f"Existing description was not generated by the pr-agent")
|
||||
self.user_description = description
|
||||
return description
|
||||
|
||||
# if the existing description was generated by the pr-agent, but it doesn't contain a user description,
|
||||
# return nothing (empty string) because it means there is no user description
|
||||
user_description_header = "### **user description**"
|
||||
if user_description_header not in description_lowercase:
|
||||
get_logger().info(f"Existing description was generated by the pr-agent, but it doesn't contain a user description")
|
||||
return ""
|
||||
|
||||
# otherwise, extract the original user description from the existing pr-agent description and return it
|
||||
# user_description_start_position = description_lowercase.find(user_description_header) + len(user_description_header)
|
||||
# return description[user_description_start_position:].split("\n", 1)[-1].strip()
|
||||
|
||||
# the 'user description' is in the beginning. extract and return it
|
||||
possible_headers = self._possible_headers()
|
||||
start_position = description_lowercase.find(user_description_header) + len(user_description_header)
|
||||
end_position = len(description)
|
||||
for header in possible_headers: # try to clip at the next header
|
||||
if header != user_description_header and header in description_lowercase:
|
||||
end_position = min(end_position, description_lowercase.find(header))
|
||||
if end_position != len(description) and end_position > start_position:
|
||||
original_user_description = description[start_position:end_position].strip()
|
||||
if original_user_description.endswith("___"):
|
||||
original_user_description = original_user_description[:-3].strip()
|
||||
else:
|
||||
original_user_description = description.split("___")[0].strip()
|
||||
if original_user_description.lower().startswith(user_description_header):
|
||||
original_user_description = original_user_description[len(user_description_header):].strip()
|
||||
|
||||
get_logger().info(f"Extracted user description from existing description",
|
||||
description=original_user_description)
|
||||
self.user_description = original_user_description
|
||||
return original_user_description
|
||||
|
||||
def _possible_headers(self):
|
||||
return ("### **user description**", "### **pr type**", "### **pr description**", "### **pr labels**", "### **type**", "### **description**",
|
||||
"### **labels**", "### 🤖 generated by pr agent")
|
||||
|
||||
def _is_generated_by_pr_agent(self, description_lowercase: str) -> bool:
|
||||
possible_headers = self._possible_headers()
|
||||
return any(description_lowercase.startswith(header) for header in possible_headers)
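A small worked example (illustrative only, not part of this commit) of how the user-description extraction above behaves on a hypothetical pr-agent generated description:

description = (
    "### **User description**\n"
    "Fix the login timeout\n"
    "___\n"
    "### **PR Type**\n"
    "Bug fix\n"
)
# _is_generated_by_pr_agent() is True because the lowercased text starts with
# "### **user description**"; the text between that header and the next known
# header ("### **pr type**") is returned, with the trailing "___" stripped:
# get_user_description() -> "Fix the login timeout"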
|
||||
|
||||
@abstractmethod
|
||||
def get_repo_settings(self):
|
||||
pass
|
||||
|
||||
def get_workspace_name(self):
|
||||
return ""
|
||||
|
||||
def get_pr_id(self):
|
||||
return ""
|
||||
|
||||
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
|
||||
return ""
|
||||
|
||||
def get_lines_link_original_file(self, filepath:str, component_range: Range) -> str:
|
||||
return ""
|
||||
|
||||
#### comments operations ####
|
||||
@abstractmethod
|
||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||
pass
|
||||
|
||||
def publish_persistent_comment(self, pr_comment: str,
|
||||
initial_header: str,
|
||||
update_header: bool = True,
|
||||
name='review',
|
||||
final_update_message=True):
|
||||
self.publish_comment(pr_comment)
|
||||
|
||||
def publish_persistent_comment_full(self, pr_comment: str,
|
||||
initial_header: str,
|
||||
update_header: bool = True,
|
||||
name='review',
|
||||
final_update_message=True):
|
||||
try:
|
||||
prev_comments = list(self.get_issue_comments())
|
||||
for comment in prev_comments:
|
||||
if comment.body.startswith(initial_header):
|
||||
latest_commit_url = self.get_latest_commit_url()
|
||||
comment_url = self.get_comment_url(comment)
|
||||
if update_header:
|
||||
updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n"
|
||||
pr_comment_updated = pr_comment.replace(initial_header, updated_header)
|
||||
else:
|
||||
pr_comment_updated = pr_comment
|
||||
get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message")
|
||||
# response = self.mr.notes.update(comment.id, {'body': pr_comment_updated})
|
||||
self.edit_comment(comment, pr_comment_updated)
|
||||
if final_update_message:
|
||||
self.publish_comment(
|
||||
f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}")
|
||||
return
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to update persistent review, error: {e}")
|
||||
pass
|
||||
self.publish_comment(pr_comment)
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
|
||||
pass
|
||||
|
||||
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
|
||||
absolute_position: int = None):
|
||||
raise NotImplementedError("This git provider does not support creating inline comments yet")
|
||||
|
||||
@abstractmethod
|
||||
def publish_inline_comments(self, comments: list[dict]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_initial_comment(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_comment(self, comment):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_issue_comments(self):
|
||||
pass
|
||||
|
||||
def get_comment_url(self, comment) -> str:
|
||||
return ""
|
||||
|
||||
#### labels operations ####
|
||||
@abstractmethod
|
||||
def publish_labels(self, labels):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pr_labels(self, update=False):
|
||||
pass
|
||||
|
||||
def get_repo_labels(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
|
||||
pass
|
||||
|
||||
#### commits operations ####
|
||||
@abstractmethod
|
||||
def get_commit_messages(self):
|
||||
pass
|
||||
|
||||
def get_pr_url(self) -> str:
|
||||
if hasattr(self, 'pr_url'):
|
||||
return self.pr_url
|
||||
return ""
|
||||
|
||||
def get_latest_commit_url(self) -> str:
|
||||
return ""
|
||||
|
||||
def auto_approve(self) -> bool:
|
||||
return False
|
||||
|
||||
def calc_pr_statistics(self, pull_request_data: dict):
|
||||
return {}
|
||||
|
||||
def get_num_of_files(self):
|
||||
try:
|
||||
return len(self.get_diff_files())
|
||||
except Exception as e:
|
||||
return -1
|
||||
|
||||
def limit_output_characters(self, output: str, max_chars: int):
|
||||
return output[:max_chars] + '...' if len(output) > max_chars else output
|
||||
|
||||
|
||||
def get_main_pr_language(languages, files) -> str:
|
||||
"""
|
||||
Get the main language of the commit. Return an empty string if it cannot be determined.
|
||||
"""
|
||||
main_language_str = ""
|
||||
if not languages:
|
||||
get_logger().info("No languages detected")
|
||||
return main_language_str
|
||||
if not files:
|
||||
get_logger().info("No files in diff")
|
||||
return main_language_str
|
||||
|
||||
try:
|
||||
top_language = max(languages, key=languages.get).lower()
|
||||
|
||||
# validate that the specific commit uses the main language
|
||||
extension_list = []
|
||||
for file in files:
|
||||
if not file:
|
||||
continue
|
||||
if isinstance(file, str):
|
||||
file = FilePatchInfo(base_file=None, head_file=None, patch=None, filename=file)
|
||||
extension_list.append(file.filename.rsplit('.')[-1])
|
||||
|
||||
# get the most common extension
|
||||
most_common_extension = '.' + max(set(extension_list), key=extension_list.count)
|
||||
try:
|
||||
language_extension_map_org = get_settings().language_extension_map_org
|
||||
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
|
||||
|
||||
if top_language in language_extension_map and most_common_extension in language_extension_map[top_language]:
|
||||
main_language_str = top_language
|
||||
else:
|
||||
for language, extensions in language_extension_map.items():
|
||||
if most_common_extension in extensions:
|
||||
main_language_str = language
|
||||
break
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to get main language: {e}")
|
||||
pass
|
||||
|
||||
## old approach:
|
||||
# most_common_extension = max(set(extension_list), key=extension_list.count)
|
||||
# if most_common_extension == 'py' and top_language == 'python' or \
|
||||
# most_common_extension == 'js' and top_language == 'javascript' or \
|
||||
# most_common_extension == 'ts' and top_language == 'typescript' or \
|
||||
# most_common_extension == 'tsx' and top_language == 'typescript' or \
|
||||
# most_common_extension == 'go' and top_language == 'go' or \
|
||||
# most_common_extension == 'java' and top_language == 'java' or \
|
||||
# most_common_extension == 'c' and top_language == 'c' or \
|
||||
# most_common_extension == 'cpp' and top_language == 'c++' or \
|
||||
# most_common_extension == 'cs' and top_language == 'c#' or \
|
||||
# most_common_extension == 'swift' and top_language == 'swift' or \
|
||||
# most_common_extension == 'php' and top_language == 'php' or \
|
||||
# most_common_extension == 'rb' and top_language == 'ruby' or \
|
||||
# most_common_extension == 'rs' and top_language == 'rust' or \
|
||||
# most_common_extension == 'scala' and top_language == 'scala' or \
|
||||
# most_common_extension == 'kt' and top_language == 'kotlin' or \
|
||||
# most_common_extension == 'pl' and top_language == 'perl' or \
|
||||
# most_common_extension == top_language:
|
||||
# main_language_str = top_language
|
||||
|
||||
except Exception as e:
|
||||
get_logger().exception(e)
|
||||
pass
|
||||
|
||||
return main_language_str
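For illustration only, and assuming the configured language_extension_map_org maps "python" to [".py"], the helper above resolves the main language like this:

languages = {"Python": 40000, "Dockerfile": 500}   # provider language statistics
files = ["apps/pr/admin.py", "apps/pr/models.py", "Dockerfile"]
# top_language          -> "python"
# extension_list        -> ["py", "py", "Dockerfile"]
# most_common_extension -> ".py"
# get_main_pr_language(languages, files) -> "python"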
|
||||
|
||||
|
||||
|
||||
|
||||
class IncrementalPR:
|
||||
def __init__(self, is_incremental: bool = False):
|
||||
self.is_incremental = is_incremental
|
||||
self.commits_range = None
|
||||
self.first_new_commit = None
|
||||
self.last_seen_commit = None
|
||||
|
||||
@property
|
||||
def first_new_commit_sha(self):
|
||||
return None if self.first_new_commit is None else self.first_new_commit.sha
|
||||
|
||||
@property
|
||||
def last_seen_commit_sha(self):
|
||||
return None if self.last_seen_commit is None else self.last_seen_commit.sha
|
||||
1066
apps/utils/pr_agent/git_providers/github_provider.py
Normal file
1066
apps/utils/pr_agent/git_providers/github_provider.py
Normal file
File diff suppressed because it is too large
Load Diff
591
apps/utils/pr_agent/git_providers/gitlab_provider.py
Normal file
591
apps/utils/pr_agent/git_providers/gitlab_provider.py
Normal file
@ -0,0 +1,591 @@
|
||||
import difflib
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import gitlab
|
||||
from gitlab import GitlabGetError
|
||||
|
||||
from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
|
||||
from ..algo.file_filter import filter_ignored
|
||||
from ..algo.language_handler import is_valid_file
|
||||
from ..algo.utils import (clip_tokens,
|
||||
find_line_number_of_relevant_line_in_file,
|
||||
load_large_diff)
|
||||
from ..config_loader import get_settings
|
||||
from ..log import get_logger
|
||||
from .git_provider import MAX_FILES_ALLOWED_FULL, GitProvider
|
||||
|
||||
|
||||
class DiffNotFoundError(Exception):
|
||||
"""Raised when the diff for a merge request cannot be found."""
|
||||
pass
|
||||
|
||||
class GitLabProvider(GitProvider):
|
||||
|
||||
def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
|
||||
gitlab_url = get_settings().get("GITLAB.URL", None)
|
||||
if not gitlab_url:
|
||||
raise ValueError("GitLab URL is not set in the config file")
|
||||
self.gitlab_url = gitlab_url
|
||||
gitlab_access_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
|
||||
if not gitlab_access_token:
|
||||
raise ValueError("GitLab personal access token is not set in the config file")
|
||||
self.gl = gitlab.Gitlab(
|
||||
url=gitlab_url,
|
||||
oauth_token=gitlab_access_token
|
||||
)
|
||||
self.max_comment_chars = 65000
|
||||
self.id_project = None
|
||||
self.id_mr = None
|
||||
self.mr = None
|
||||
self.diff_files = None
|
||||
self.git_files = None
|
||||
self.temp_comments = []
|
||||
self.pr_url = merge_request_url
|
||||
self._set_merge_request(merge_request_url)
|
||||
self.RE_HUNK_HEADER = re.compile(
|
||||
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
|
||||
self.incremental = incremental
|
||||
|
||||
def is_supported(self, capability: str) -> bool:
|
||||
if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments',
|
||||
'publish_file_comments']:  # gfm_markdown is supported in GitLab
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def pr(self):
|
||||
'''The GitLab terminology is merge request (MR) instead of pull request (PR)'''
|
||||
return self.mr
|
||||
|
||||
def _set_merge_request(self, merge_request_url: str):
|
||||
self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url)
|
||||
self.mr = self._get_merge_request()
|
||||
try:
|
||||
self.last_diff = self.mr.diffs.list(get_all=True)[-1]
|
||||
except IndexError as e:
|
||||
get_logger().error(f"Could not get diff for merge request {self.id_mr}")
|
||||
raise DiffNotFoundError(f"Could not get diff for merge request {self.id_mr}") from e
|
||||
|
||||
|
||||
def get_pr_file_content(self, file_path: str, branch: str) -> str:
|
||||
try:
|
||||
return self.gl.projects.get(self.id_project).files.get(file_path, branch).decode()
|
||||
except GitlabGetError:
|
||||
# In case of file creation the method returns GitlabGetError (404 file not found).
|
||||
# In this case we return an empty string for the diff.
|
||||
return ''
|
||||
|
||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||
"""
|
||||
Retrieves the list of files that have been modified, added, deleted, or renamed in a pull request in GitLab,
|
||||
along with their content and patch information.
|
||||
|
||||
Returns:
|
||||
diff_files (List[FilePatchInfo]): List of FilePatchInfo objects representing the modified, added, deleted,
|
||||
or renamed files in the merge request.
|
||||
"""
|
||||
|
||||
if self.diff_files:
|
||||
return self.diff_files
|
||||
|
||||
# filter files using [ignore] patterns
|
||||
diffs_original = self.mr.changes()['changes']
|
||||
diffs = filter_ignored(diffs_original, 'gitlab')
|
||||
if diffs != diffs_original:
|
||||
try:
|
||||
names_original = [diff['new_path'] for diff in diffs_original]
|
||||
names_filtered = [diff['new_path'] for diff in diffs]
|
||||
get_logger().info(f"Filtered out [ignore] files for merge request {self.id_mr}", extra={
|
||||
'original_files': names_original,
|
||||
'filtered_files': names_filtered
|
||||
})
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
diff_files = []
|
||||
invalid_files_names = []
|
||||
counter_valid = 0
|
||||
for diff in diffs:
|
||||
if not is_valid_file(diff['new_path']):
|
||||
invalid_files_names.append(diff['new_path'])
|
||||
continue
|
||||
|
||||
# allow only a limited number of files to be fully loaded. We can manage the rest with diffs only
|
||||
counter_valid += 1
|
||||
if counter_valid < MAX_FILES_ALLOWED_FULL or not diff['diff']:
|
||||
original_file_content_str = self.get_pr_file_content(diff['old_path'], self.mr.diff_refs['base_sha'])
|
||||
new_file_content_str = self.get_pr_file_content(diff['new_path'], self.mr.diff_refs['head_sha'])
|
||||
else:
|
||||
if counter_valid == MAX_FILES_ALLOWED_FULL:
|
||||
get_logger().info(f"Too many files in PR, will avoid loading full content for rest of files")
|
||||
original_file_content_str = ''
|
||||
new_file_content_str = ''
|
||||
|
||||
try:
|
||||
if isinstance(original_file_content_str, bytes):
|
||||
original_file_content_str = bytes.decode(original_file_content_str, 'utf-8')
|
||||
if isinstance(new_file_content_str, bytes):
|
||||
new_file_content_str = bytes.decode(new_file_content_str, 'utf-8')
|
||||
except UnicodeDecodeError:
|
||||
get_logger().warning(
|
||||
f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}")
|
||||
|
||||
edit_type = EDIT_TYPE.MODIFIED
|
||||
if diff['new_file']:
|
||||
edit_type = EDIT_TYPE.ADDED
|
||||
elif diff['deleted_file']:
|
||||
edit_type = EDIT_TYPE.DELETED
|
||||
elif diff['renamed_file']:
|
||||
edit_type = EDIT_TYPE.RENAMED
|
||||
|
||||
filename = diff['new_path']
|
||||
patch = diff['diff']
|
||||
if not patch:
|
||||
patch = load_large_diff(filename, new_file_content_str, original_file_content_str)
|
||||
|
||||
|
||||
# count number of lines added and removed
|
||||
patch_lines = patch.splitlines(keepends=True)
|
||||
num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
|
||||
num_minus_lines = len([line for line in patch_lines if line.startswith('-')])
|
||||
diff_files.append(
|
||||
FilePatchInfo(original_file_content_str, new_file_content_str,
|
||||
patch=patch,
|
||||
filename=filename,
|
||||
edit_type=edit_type,
|
||||
old_filename=None if diff['old_path'] == diff['new_path'] else diff['old_path'],
|
||||
num_plus_lines=num_plus_lines,
|
||||
num_minus_lines=num_minus_lines, ))
|
||||
if invalid_files_names:
|
||||
get_logger().info(f"Filtered out files with invalid extensions: {invalid_files_names}")
|
||||
|
||||
self.diff_files = diff_files
|
||||
return diff_files
|
||||
|
||||
def get_files(self) -> list:
|
||||
if not self.git_files:
|
||||
self.git_files = [change['new_path'] for change in self.mr.changes()['changes']]
|
||||
return self.git_files
|
||||
|
||||
def publish_description(self, pr_title: str, pr_body: str):
|
||||
try:
|
||||
self.mr.title = pr_title
|
||||
self.mr.description = pr_body
|
||||
self.mr.save()
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Could not update merge request {self.id_mr} description: {e}")
|
||||
|
||||
def get_latest_commit_url(self):
|
||||
return self.mr.commits().next().web_url
|
||||
|
||||
def get_comment_url(self, comment):
|
||||
return f"{self.mr.web_url}#note_{comment.id}"
|
||||
|
||||
def publish_persistent_comment(self, pr_comment: str,
|
||||
initial_header: str,
|
||||
update_header: bool = True,
|
||||
name='review',
|
||||
final_update_message=True):
|
||||
self.publish_persistent_comment_full(pr_comment, initial_header, update_header, name, final_update_message)
|
||||
|
||||
def publish_comment(self, mr_comment: str, is_temporary: bool = False):
|
||||
if is_temporary and not get_settings().config.publish_output_progress:
|
||||
get_logger().debug(f"Skipping publish_comment for temporary comment: {mr_comment}")
|
||||
return None
|
||||
mr_comment = self.limit_output_characters(mr_comment, self.max_comment_chars)
|
||||
comment = self.mr.notes.create({'body': mr_comment})
|
||||
if is_temporary:
|
||||
self.temp_comments.append(comment)
|
||||
return comment
|
||||
|
||||
def edit_comment(self, comment, body: str):
|
||||
body = self.limit_output_characters(body, self.max_comment_chars)
|
||||
self.mr.notes.update(comment.id, {'body': body})
|
||||
|
||||
def edit_comment_from_comment_id(self, comment_id: int, body: str):
|
||||
body = self.limit_output_characters(body, self.max_comment_chars)
|
||||
comment = self.mr.notes.get(comment_id)
|
||||
comment.body = body
|
||||
comment.save()
|
||||
|
||||
def reply_to_comment_from_comment_id(self, comment_id: int, body: str):
|
||||
body = self.limit_output_characters(body, self.max_comment_chars)
|
||||
discussion = self.mr.discussions.get(comment_id)
|
||||
discussion.notes.create({'body': body})
|
||||
|
||||
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
|
||||
body = self.limit_output_characters(body, self.max_comment_chars)
|
||||
edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file,
|
||||
relevant_line_in_file)
|
||||
self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
|
||||
target_file, target_line_no, original_suggestion)
|
||||
|
||||
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, absolute_position: int = None):
|
||||
raise NotImplementedError("Gitlab provider does not support creating inline comments yet")
|
||||
|
||||
def create_inline_comments(self, comments: list[dict]):
|
||||
raise NotImplementedError("Gitlab provider does not support publishing inline comments yet")
|
||||
|
||||
def get_comment_body_from_comment_id(self, comment_id: int):
|
||||
comment = self.mr.notes.get(comment_id).body
|
||||
return comment
|
||||
|
||||
def send_inline_comment(self, body: str, edit_type: str, found: bool, relevant_file: str,
|
||||
relevant_line_in_file: str,
|
||||
source_line_no: int, target_file: str, target_line_no: int,
|
||||
original_suggestion=None) -> None:
|
||||
if not found:
|
||||
get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
|
||||
else:
|
||||
# in order to have exact SHAs, we have to find the correct diff for this change
|
||||
diff = self.get_relevant_diff(relevant_file, relevant_line_in_file)
|
||||
if diff is None:
|
||||
get_logger().error(f"Could not get diff for merge request {self.id_mr}")
|
||||
raise DiffNotFoundError(f"Could not get diff for merge request {self.id_mr}")
|
||||
pos_obj = {'position_type': 'text',
|
||||
'new_path': target_file.filename,
|
||||
'old_path': target_file.old_filename if target_file.old_filename else target_file.filename,
|
||||
'base_sha': diff.base_commit_sha, 'start_sha': diff.start_commit_sha, 'head_sha': diff.head_commit_sha}
|
||||
if edit_type == 'deletion':
|
||||
pos_obj['old_line'] = source_line_no - 1
|
||||
elif edit_type == 'addition':
|
||||
pos_obj['new_line'] = target_line_no - 1
|
||||
else:
|
||||
pos_obj['new_line'] = target_line_no - 1
|
||||
pos_obj['old_line'] = source_line_no - 1
|
||||
get_logger().debug(f"Creating comment in MR {self.id_mr} with body {body} and position {pos_obj}")
|
||||
try:
|
||||
self.mr.discussions.create({'body': body, 'position': pos_obj})
|
||||
except Exception as e:
|
||||
try:
|
||||
# fallback - create a general note on the file in the MR
|
||||
if 'suggestion_orig_location' in original_suggestion:
|
||||
line_start = original_suggestion['suggestion_orig_location']['start_line']
|
||||
line_end = original_suggestion['suggestion_orig_location']['end_line']
|
||||
old_code_snippet = original_suggestion['prev_code_snippet']
|
||||
new_code_snippet = original_suggestion['new_code_snippet']
|
||||
content = original_suggestion['suggestion_summary']
|
||||
label = original_suggestion['category']
|
||||
if 'score' in original_suggestion:
|
||||
score = original_suggestion['score']
|
||||
else:
|
||||
score = 7
|
||||
else:
|
||||
line_start = original_suggestion['relevant_lines_start']
|
||||
line_end = original_suggestion['relevant_lines_end']
|
||||
old_code_snippet = original_suggestion['existing_code']
|
||||
new_code_snippet = original_suggestion['improved_code']
|
||||
content = original_suggestion['suggestion_content']
|
||||
label = original_suggestion['label']
|
||||
score = original_suggestion.get('score', 7)
|
||||
|
||||
if hasattr(self, 'main_language'):
|
||||
language = self.main_language
|
||||
else:
|
||||
language = ''
|
||||
link = self.get_line_link(relevant_file, line_start, line_end)
|
||||
body_fallback =f"**Suggestion:** {content} [{label}, importance: {score}]\n\n"
|
||||
body_fallback +=f"\n\n<details><summary>[{target_file.filename} [{line_start}-{line_end}]]({link}):</summary>\n\n"
|
||||
body_fallback += f"\n\n___\n\n`(Cannot implement directly - GitLab API allows committable suggestions strictly on MR diff lines)`"
|
||||
body_fallback+="</details>\n\n"
|
||||
diff_patch = difflib.unified_diff(old_code_snippet.split('\n'),
|
||||
new_code_snippet.split('\n'), n=999)
|
||||
patch_orig = "\n".join(diff_patch)
|
||||
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
|
||||
diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
|
||||
body_fallback += diff_code
|
||||
|
||||
# Create a general note on the file in the MR
|
||||
self.mr.notes.create({
|
||||
'body': body_fallback,
|
||||
'position': {
|
||||
'base_sha': diff.base_commit_sha,
|
||||
'start_sha': diff.start_commit_sha,
|
||||
'head_sha': diff.head_commit_sha,
|
||||
'position_type': 'text',
|
||||
'file_path': f'{target_file.filename}',
|
||||
}
|
||||
})
|
||||
get_logger().debug(f"Created fallback comment in MR {self.id_mr} with position {pos_obj}")
|
||||
|
||||
# get_logger().debug(
|
||||
# f"Failed to create comment in MR {self.id_mr} with position {pos_obj} (probably not a '+' line)")
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to create comment in MR {self.id_mr}")
|
||||
|
||||
def get_relevant_diff(self, relevant_file: str, relevant_line_in_file: str) -> Optional[dict]:
|
||||
changes = self.mr.changes() # Retrieve the changes for the merge request once
|
||||
if not changes:
|
||||
get_logger().error('No changes found for the merge request.')
|
||||
return None
|
||||
all_diffs = self.mr.diffs.list(get_all=True)
|
||||
if not all_diffs:
|
||||
get_logger().error('No diffs found for the merge request.')
|
||||
return None
|
||||
for diff in all_diffs:
|
||||
for change in changes['changes']:
|
||||
if change['new_path'] == relevant_file and relevant_line_in_file in change['diff']:
|
||||
return diff
|
||||
get_logger().debug(
|
||||
f'No relevant diff found for {relevant_file} {relevant_line_in_file}. Falling back to last diff.')
|
||||
return self.last_diff # fallback to last_diff if no relevant diff is found
|
||||
|
||||
def publish_code_suggestions(self, code_suggestions: list) -> bool:
|
||||
for suggestion in code_suggestions:
|
||||
try:
|
||||
if suggestion and 'original_suggestion' in suggestion:
|
||||
original_suggestion = suggestion['original_suggestion']
|
||||
else:
|
||||
original_suggestion = suggestion
|
||||
body = suggestion['body']
|
||||
relevant_file = suggestion['relevant_file']
|
||||
relevant_lines_start = suggestion['relevant_lines_start']
|
||||
relevant_lines_end = suggestion['relevant_lines_end']
|
||||
|
||||
diff_files = self.get_diff_files()
|
||||
target_file = None
|
||||
for file in diff_files:
|
||||
if file.filename == relevant_file:
|
||||
|
||||
target_file = file
|
||||
break
|
||||
range = relevant_lines_end - relevant_lines_start # no need to add 1
|
||||
body = body.replace('```suggestion', f'```suggestion:-0+{range}')
|
||||
lines = target_file.head_file.splitlines()
|
||||
relevant_line_in_file = lines[relevant_lines_start - 1]
|
||||
|
||||
# edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(target_file,
|
||||
# relevant_line_in_file)
|
||||
# for code suggestions, we want to edit the new code
|
||||
source_line_no = -1
|
||||
target_line_no = relevant_lines_start + 1
|
||||
found = True
|
||||
edit_type = 'addition'
|
||||
|
||||
self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
|
||||
target_file, target_line_no, original_suggestion)
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Could not publish code suggestion:\nsuggestion: {suggestion}\nerror: {e}")
|
||||
|
||||
# Note that we publish suggestions one by one, so if one fails, the rest will still be published.
|
||||
return True
|
||||
|
||||
def publish_file_comments(self, file_comments: list) -> bool:
|
||||
pass
|
||||
|
||||
def search_line(self, relevant_file, relevant_line_in_file):
|
||||
target_file = None
|
||||
|
||||
edit_type = self.get_edit_type(relevant_line_in_file)
|
||||
for file in self.get_diff_files():
|
||||
if file.filename == relevant_file:
|
||||
edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(file,
|
||||
relevant_line_in_file)
|
||||
return edit_type, found, source_line_no, target_file, target_line_no
|
||||
|
||||
def find_in_file(self, file, relevant_line_in_file):
|
||||
edit_type = 'context'
|
||||
source_line_no = 0
|
||||
target_line_no = 0
|
||||
found = False
|
||||
target_file = file
|
||||
patch = file.patch
|
||||
patch_lines = patch.splitlines()
|
||||
for line in patch_lines:
|
||||
if line.startswith('@@'):
|
||||
match = self.RE_HUNK_HEADER.match(line)
|
||||
if not match:
|
||||
continue
|
||||
start_old, size_old, start_new, size_new, _ = match.groups()
|
||||
source_line_no = int(start_old)
|
||||
target_line_no = int(start_new)
|
||||
continue
|
||||
if line.startswith('-'):
|
||||
source_line_no += 1
|
||||
elif line.startswith('+'):
|
||||
target_line_no += 1
|
||||
elif line.startswith(' '):
|
||||
source_line_no += 1
|
||||
target_line_no += 1
|
||||
if relevant_line_in_file in line:
|
||||
found = True
|
||||
edit_type = self.get_edit_type(line)
|
||||
break
|
||||
elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:].lstrip() in line:
|
||||
# The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
|
||||
# it's a context line
|
||||
found = True
|
||||
edit_type = self.get_edit_type(line)
|
||||
break
|
||||
return edit_type, found, source_line_no, target_file, target_line_no
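An illustrative trace (hypothetical hunk header, not part of this commit) of the hunk-header regex that find_in_file() relies on; re is already imported at the top of this module:

m = re.match(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)",
             "@@ -12,7 +12,9 @@ def publish_comment(self):")
# m.groups() -> ('12', '7', '12', '9', 'def publish_comment(self):')
# so the scan starts this hunk with source_line_no = 12 and target_line_no = 12.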
|
||||
|
||||
def get_edit_type(self, relevant_line_in_file):
|
||||
edit_type = 'context'
|
||||
if relevant_line_in_file[0] == '-':
|
||||
edit_type = 'deletion'
|
||||
elif relevant_line_in_file[0] == '+':
|
||||
edit_type = 'addition'
|
||||
return edit_type
|
||||
|
||||
def remove_initial_comment(self):
|
||||
try:
|
||||
for comment in self.temp_comments:
|
||||
self.remove_comment(comment)
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to remove temp comments, error: {e}")
|
||||
|
||||
def remove_comment(self, comment):
|
||||
try:
|
||||
comment.delete()
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to remove comment, error: {e}")
|
||||
|
||||
def get_title(self):
|
||||
return self.mr.title
|
||||
|
||||
def get_languages(self):
|
||||
languages = self.gl.projects.get(self.id_project).languages()
|
||||
return languages
|
||||
|
||||
def get_pr_branch(self):
|
||||
return self.mr.source_branch
|
||||
|
||||
def get_pr_owner_id(self) -> str | None:
|
||||
if not self.gitlab_url or 'gitlab.com' in self.gitlab_url:
|
||||
if not self.id_project:
|
||||
return None
|
||||
return self.id_project.split('/')[0]
|
||||
# extract host name
|
||||
host = urlparse(self.gitlab_url).hostname
|
||||
return host
|
||||
|
||||
def get_pr_description_full(self):
|
||||
return self.mr.description
|
||||
|
||||
def get_issue_comments(self):
|
||||
return self.mr.notes.list(get_all=True)[::-1]
|
||||
|
||||
def get_repo_settings(self):
|
||||
try:
|
||||
contents = self.gl.projects.get(self.id_project).files.get(file_path='.pr_agent.toml', ref=self.mr.target_branch).decode()
|
||||
return contents
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def get_workspace_name(self):
|
||||
return self.id_project.split('/')[0]
|
||||
|
||||
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
|
||||
return True
|
||||
|
||||
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
|
||||
return True
|
||||
|
||||
def _parse_merge_request_url(self, merge_request_url: str) -> Tuple[str, int]:
|
||||
parsed_url = urlparse(merge_request_url)
|
||||
|
||||
path_parts = parsed_url.path.strip('/').split('/')
|
||||
if 'merge_requests' not in path_parts:
|
||||
raise ValueError("The provided URL does not appear to be a GitLab merge request URL")
|
||||
|
||||
mr_index = path_parts.index('merge_requests')
|
||||
# Ensure there is an ID after 'merge_requests'
|
||||
if len(path_parts) <= mr_index + 1:
|
||||
raise ValueError("The provided URL does not contain a merge request ID")
|
||||
|
||||
try:
|
||||
mr_id = int(path_parts[mr_index + 1])
|
||||
except ValueError as e:
|
||||
raise ValueError("Unable to convert merge request ID to integer") from e
|
||||
|
||||
# Handle special delimiter (-)
|
||||
project_path = "/".join(path_parts[:mr_index])
|
||||
if project_path.endswith('/-'):
|
||||
project_path = project_path[:-2]
|
||||
|
||||
# Return the path before 'merge_requests' and the ID
|
||||
return project_path, mr_id
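A traced example (hypothetical URL, not part of this commit) of what the parser above returns:

url = "https://gitlab.example.com/group/subgroup/project/-/merge_requests/42"
# path_parts   -> ['group', 'subgroup', 'project', '-', 'merge_requests', '42']
# mr_id        -> 42
# project_path -> 'group/subgroup/project'   (the trailing '/-' is stripped)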
|
||||
|
||||
def _get_merge_request(self):
|
||||
mr = self.gl.projects.get(self.id_project).mergerequests.get(self.id_mr)
|
||||
return mr
|
||||
|
||||
def get_user_id(self):
|
||||
return None
|
||||
|
||||
def publish_labels(self, pr_types):
|
||||
try:
|
||||
self.mr.labels = list(set(pr_types))
|
||||
self.mr.save()
|
||||
except Exception as e:
|
||||
get_logger().warning(f"Failed to publish labels, error: {e}")
|
||||
|
||||
def publish_inline_comments(self, comments: list[dict]):
|
||||
pass
|
||||
|
||||
def get_pr_labels(self, update=False):
|
||||
return self.mr.labels
|
||||
|
||||
def get_repo_labels(self):
|
||||
return self.gl.projects.get(self.id_project).labels.list()
|
||||
|
||||
def get_commit_messages(self):
|
||||
"""
|
||||
Retrieves the commit messages of a pull request.
|
||||
|
||||
Returns:
|
||||
str: A string containing the commit messages of the pull request.
|
||||
"""
|
||||
max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None)
|
||||
try:
|
||||
commit_messages_list = [commit['message'] for commit in self.mr.commits()._list]
|
||||
commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages_list)])
|
||||
except Exception:
|
||||
commit_messages_str = ""
|
||||
if max_tokens:
|
||||
commit_messages_str = clip_tokens(commit_messages_str, max_tokens)
|
||||
return commit_messages_str
|
||||
|
||||
def get_pr_id(self):
|
||||
try:
|
||||
pr_id = self.mr.web_url
|
||||
return pr_id
|
||||
except:
|
||||
return ""
|
||||
|
||||
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
|
||||
if relevant_line_start == -1:
|
||||
link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads"
|
||||
elif relevant_line_end:
|
||||
link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads#L{relevant_line_start}-{relevant_line_end}"
|
||||
else:
|
||||
link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads#L{relevant_line_start}"
|
||||
return link
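For illustration, with a hypothetical source branch named feature-x, the links built above take these shapes:

# get_line_link('apps/pr/admin.py', 10, 12) ->
#   {self.gl.url}/{self.id_project}/-/blob/feature-x/apps/pr/admin.py?ref_type=heads#L10-12
# get_line_link('apps/pr/admin.py', -1) ->
#   {self.gl.url}/{self.id_project}/-/blob/feature-x/apps/pr/admin.py?ref_type=heads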
|
||||
|
||||
|
||||
def generate_link_to_relevant_line_number(self, suggestion) -> str:
|
||||
try:
|
||||
relevant_file = suggestion['relevant_file'].strip('`').strip("'").rstrip()
|
||||
relevant_line_str = suggestion['relevant_line'].rstrip()
|
||||
if not relevant_line_str:
|
||||
return ""
|
||||
|
||||
position, absolute_position = find_line_number_of_relevant_line_in_file \
|
||||
(self.diff_files, relevant_file, relevant_line_str)
|
||||
|
||||
if absolute_position != -1:
|
||||
# link to right file only
|
||||
link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads#L{absolute_position}"
|
||||
|
||||
# # link to diff
|
||||
# sha_file = hashlib.sha1(relevant_file.encode('utf-8')).hexdigest()
|
||||
# link = f"{self.pr.web_url}/diffs#{sha_file}_{absolute_position}_{absolute_position}"
|
||||
return link
|
||||
except Exception as e:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Failed adding line link, error: {e}")
|
||||
|
||||
return ""
|
||||
192
apps/utils/pr_agent/git_providers/local_git_provider.py
Normal file
192
apps/utils/pr_agent/git_providers/local_git_provider.py
Normal file
@ -0,0 +1,192 @@
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from git import Repo
|
||||
|
||||
from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
from utils.pr_agent.config_loader import _find_repository_root, get_settings
|
||||
from utils.pr_agent.git_providers.git_provider import GitProvider
|
||||
from utils.pr_agent.log import get_logger
|
||||
|
||||
|
||||
class PullRequestMimic:
|
||||
"""
|
||||
This class mimics the PullRequest class from the PyGithub library for the LocalGitProvider.
|
||||
"""
|
||||
|
||||
def __init__(self, title: str, diff_files: List[FilePatchInfo]):
|
||||
self.title = title
|
||||
self.diff_files = diff_files
|
||||
|
||||
|
||||
class LocalGitProvider(GitProvider):
|
||||
"""
|
||||
This class implements the GitProvider interface for local git repositories.
|
||||
It mimics the PR functionality of the GitProvider interface,
|
||||
but does not require a hosted git repository.
|
||||
Instead of providing a PR url, the user provides a local branch path to generate a diff-patch.
|
||||
For the MVP it only supports the /review and /describe capabilities.
|
||||
"""
|
||||
|
||||
def __init__(self, target_branch_name, incremental=False):
|
||||
self.repo_path = _find_repository_root()
|
||||
if self.repo_path is None:
|
||||
raise ValueError('Could not find repository root')
|
||||
self.repo = Repo(self.repo_path)
|
||||
self.head_branch_name = self.repo.head.ref.name
|
||||
self.target_branch_name = target_branch_name
|
||||
self._prepare_repo()
|
||||
self.diff_files = None
|
||||
self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())
|
||||
self.description_path = get_settings().get('local.description_path') \
|
||||
if get_settings().get('local.description_path') is not None else self.repo_path / 'description.md'
|
||||
self.review_path = get_settings().get('local.review_path') \
|
||||
if get_settings().get('local.review_path') is not None else self.repo_path / 'review.md'
|
||||
# inline code comments are not supported for local git repositories
|
||||
get_settings().pr_reviewer.inline_code_comments = False
|
||||
|
||||
def _prepare_repo(self):
|
||||
"""
|
||||
Prepare the repository for PR-mimic generation.
|
||||
"""
|
||||
get_logger().debug('Preparing repository for PR-mimic generation...')
|
||||
if self.repo.is_dirty():
|
||||
raise ValueError('The repository is not in a clean state. Please commit or stash pending changes.')
|
||||
if self.target_branch_name not in self.repo.heads:
|
||||
raise KeyError(f'Branch: {self.target_branch_name} does not exist')
|
||||
|
||||
def is_supported(self, capability: str) -> bool:
|
||||
if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments', 'get_labels',
|
||||
'gfm_markdown']:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||
diffs = self.repo.head.commit.diff(
|
||||
self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]),
|
||||
create_patch=True,
|
||||
R=True
|
||||
)
|
||||
diff_files = []
|
||||
for diff_item in diffs:
|
||||
if diff_item.a_blob is not None:
|
||||
original_file_content_str = diff_item.a_blob.data_stream.read().decode('utf-8')
|
||||
else:
|
||||
original_file_content_str = "" # empty file
|
||||
if diff_item.b_blob is not None:
|
||||
new_file_content_str = diff_item.b_blob.data_stream.read().decode('utf-8')
|
||||
else:
|
||||
new_file_content_str = "" # empty file
|
||||
edit_type = EDIT_TYPE.MODIFIED
|
||||
if diff_item.new_file:
|
||||
edit_type = EDIT_TYPE.ADDED
|
||||
elif diff_item.deleted_file:
|
||||
edit_type = EDIT_TYPE.DELETED
|
||||
elif diff_item.renamed_file:
|
||||
edit_type = EDIT_TYPE.RENAMED
|
||||
diff_files.append(
|
||||
FilePatchInfo(original_file_content_str,
|
||||
new_file_content_str,
|
||||
diff_item.diff.decode('utf-8'),
|
||||
diff_item.b_path,
|
||||
edit_type=edit_type,
|
||||
old_filename=None if diff_item.a_path == diff_item.b_path else diff_item.a_path
|
||||
)
|
||||
)
|
||||
self.diff_files = diff_files
|
||||
return diff_files
|
||||
|
||||
def get_files(self) -> List[str]:
|
||||
"""
|
||||
Returns a list of files with changes in the diff.
|
||||
"""
|
||||
diff_index = self.repo.head.commit.diff(
|
||||
self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]),
|
||||
R=True
|
||||
)
|
||||
# Get the list of changed files
|
||||
diff_files = [item.a_path for item in diff_index]
|
||||
return diff_files
|
||||
|
||||
def publish_description(self, pr_title: str, pr_body: str):
|
||||
with open(self.description_path, "w") as file:
|
||||
# Write the string to the file
|
||||
file.write(pr_title + '\n' + pr_body)
|
||||
|
||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||
with open(self.review_path, "w") as file:
|
||||
# Write the string to the file
|
||||
file.write(pr_comment)
|
||||
|
||||
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
|
||||
raise NotImplementedError('Publishing inline comments is not implemented for the local git provider')
|
||||
|
||||
def publish_inline_comments(self, comments: list[dict]):
|
||||
raise NotImplementedError('Publishing inline comments is not implemented for the local git provider')
|
||||
|
||||
def publish_code_suggestion(self, body: str, relevant_file: str,
|
||||
relevant_lines_start: int, relevant_lines_end: int):
|
||||
raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')
|
||||
|
||||
def publish_code_suggestions(self, code_suggestions: list) -> bool:
|
||||
raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')
|
||||
|
||||
def publish_labels(self, labels):
|
||||
pass # Not applicable to the local git provider, but required by the interface
|
||||
|
||||
def remove_initial_comment(self):
|
||||
pass # Not applicable to the local git provider, but required by the interface
|
||||
|
||||
def remove_comment(self, comment):
|
||||
pass # Not applicable to the local git provider, but required by the interface
|
||||
|
||||
def add_eyes_reaction(self, comment):
|
||||
pass # Not applicable to the local git provider, but required by the interface
|
||||
|
||||
def get_commit_messages(self):
|
||||
pass # Not applicable to the local git provider, but required by the interface
|
||||
|
||||
def get_repo_settings(self):
|
||||
pass # Not applicable to the local git provider, but required by the interface
|
||||
|
||||
def remove_reaction(self, comment):
|
||||
pass # Not applicable to the local git provider, but required by the interface
|
||||
|
||||
def get_languages(self):
|
||||
"""
|
||||
Calculate percentage of languages in repository. Used for hunk prioritisation.
|
||||
"""
|
||||
# Get all files in repository
|
||||
filepaths = [Path(item.path) for item in self.repo.tree().traverse() if item.type == 'blob']
|
||||
# Identify language by file extension and count
|
||||
lang_count = Counter(ext.lstrip('.') for filepath in filepaths for ext in [filepath.suffix.lower()])
|
||||
# Convert counts to percentages
|
||||
total_files = len(filepaths)
|
||||
lang_percentage = {lang: count / total_files * 100 for lang, count in lang_count.items()}
|
||||
return lang_percentage
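A simplified mirror of the calculation above (illustrative only, with a hypothetical file tree):

from collections import Counter
filepaths = ["main.py", "utils.py", "run.sh", "README.md"]
lang_count = Counter(p.rsplit(".", 1)[-1].lower() for p in filepaths)
lang_percentage = {lang: count / len(filepaths) * 100 for lang, count in lang_count.items()}
# -> {'py': 50.0, 'sh': 25.0, 'md': 25.0}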
|
||||
|
||||
def get_pr_branch(self):
|
||||
return self.repo.head
|
||||
|
||||
def get_user_id(self):
|
||||
return -1 # Not used anywhere for the local provider, but required by the interface
|
||||
|
||||
def get_pr_description_full(self):
|
||||
commits_diff = list(self.repo.iter_commits(self.target_branch_name + '..HEAD'))
|
||||
# Get the commit messages and concatenate
|
||||
commit_messages = " ".join([commit.message for commit in commits_diff])
|
||||
# TODO Handle the description better - maybe use gpt-3.5 summarisation here?
|
||||
return commit_messages[:200] # Use max 200 characters
|
||||
|
||||
def get_pr_title(self):
|
||||
"""
|
||||
Substitutes the branch-name as the PR-mimic title.
|
||||
"""
|
||||
return self.head_branch_name
|
||||
|
||||
def get_issue_comments(self):
|
||||
raise NotImplementedError('Getting issue comments is not implemented for the local git provider')
|
||||
|
||||
def get_pr_labels(self, update=False):
|
||||
raise NotImplementedError('Getting labels is not implemented for the local git provider')
|
||||
102
apps/utils/pr_agent/git_providers/utils.py
Normal file
102
apps/utils/pr_agent/git_providers/utils.py
Normal file
@ -0,0 +1,102 @@
|
||||
import copy
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from dynaconf import Dynaconf
|
||||
from starlette_context import context
|
||||
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
from utils.pr_agent.git_providers import (get_git_provider_with_context)
|
||||
from utils.pr_agent.log import get_logger
|
||||
|
||||
|
||||
def apply_repo_settings(pr_url):
|
||||
git_provider = get_git_provider_with_context(pr_url)
|
||||
if get_settings().config.use_repo_settings_file:
|
||||
repo_settings_file = None
|
||||
try:
|
||||
try:
|
||||
repo_settings = context.get("repo_settings", None)
|
||||
except Exception:
|
||||
repo_settings = None
|
||||
pass
|
||||
if repo_settings is None: # None is different from "", which is a valid value
|
||||
repo_settings = git_provider.get_repo_settings()
|
||||
try:
|
||||
context["repo_settings"] = repo_settings
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
error_local = None
|
||||
if repo_settings:
|
||||
repo_settings_file = None
|
||||
category = 'local'
|
||||
try:
|
||||
fd, repo_settings_file = tempfile.mkstemp(suffix='.toml')
|
||||
os.write(fd, repo_settings)
|
||||
new_settings = Dynaconf(settings_files=[repo_settings_file])
|
||||
for section, contents in new_settings.as_dict().items():
|
||||
section_dict = copy.deepcopy(get_settings().as_dict().get(section, {}))
|
||||
for key, value in contents.items():
|
||||
section_dict[key] = value
|
||||
get_settings().unset(section)
|
||||
get_settings().set(section, section_dict, merge=False)
|
||||
get_logger().info(f"Applying repo settings:\n{new_settings.as_dict()}")
|
||||
except Exception as e:
|
||||
get_logger().warning(f"Failed to apply repo {category} settings, error: {str(e)}")
|
||||
error_local = {'error': str(e), 'settings': repo_settings, 'category': category}
|
||||
|
||||
if error_local:
|
||||
handle_configurations_errors([error_local], git_provider)
|
||||
except Exception as e:
|
||||
get_logger().exception("Failed to apply repo settings", e)
|
||||
finally:
|
||||
if repo_settings_file:
|
||||
try:
|
||||
os.remove(repo_settings_file)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to remove temporary settings file {repo_settings_file}", e)
|
||||
|
||||
# enable switching models with a short definition
|
||||
if get_settings().config.model.lower() == 'claude-3-5-sonnet':
|
||||
set_claude_model()
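For illustration, a repo-level .pr_agent.toml of the following shape (the keys shown are examples, not an exhaustive reference) would be merged section by section into the global settings by the function above:

repo_settings = b"""
[config]
model = "gpt-4o"

[pr_reviewer]
extra_instructions = "Focus on security issues"
"""
# apply_repo_settings() writes these bytes to a temporary .toml file, loads it with
# Dynaconf, and overlays each section onto get_settings(), so
# get_settings().pr_reviewer.extra_instructions would then reflect the repo value.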
|
||||
|
||||
|
||||
def handle_configurations_errors(config_errors, git_provider):
|
||||
try:
|
||||
if not any(config_errors):
|
||||
return
|
||||
|
||||
for err in config_errors:
|
||||
if err:
|
||||
configuration_file_content = err['settings'].decode()
|
||||
err_message = err['error']
|
||||
config_type = err['category']
|
||||
header = f"❌ **PR-Agent failed to apply '{config_type}' repo settings**"
|
||||
body = f"{header}\n\nThe configuration file needs to be a valid [TOML](https://qodo-merge-docs.qodo.ai/usage-guide/configuration_options/), please fix it.\n\n"
|
||||
body += f"___\n\n**Error message:**\n`{err_message}`\n\n"
|
||||
if git_provider.is_supported("gfm_markdown"):
|
||||
body += f"\n\n<details><summary>配置内容:</summary>\n\n```toml\n{configuration_file_content}\n```\n\n</details>"
|
||||
else:
|
||||
body += f"\n\n**配置内容:**\n\n```toml\n{configuration_file_content}\n```\n\n"
|
||||
get_logger().warning(f"Sending a 'configuration error' comment to the PR", artifact={'body': body})
|
||||
# git_provider.publish_comment(body)
|
||||
if hasattr(git_provider, 'publish_persistent_comment'):
|
||||
git_provider.publish_persistent_comment(body,
|
||||
initial_header=header,
|
||||
update_header=False,
|
||||
final_update_message=False)
|
||||
else:
|
||||
git_provider.publish_comment(body)
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to handle configurations errors", e)
|
||||
|
||||
|
||||
def set_claude_model():
|
||||
"""
|
||||
Set the claude-3-5-sonnet model easily (even by users), just by stating: --config.model='claude-3-5-sonnet'
|
||||
"""
|
||||
model_claude = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||
get_settings().set('config.model', model_claude)
|
||||
get_settings().set('config.model_weak', model_claude)
|
||||
get_settings().set('config.fallback_models', [model_claude])
|
||||
14
apps/utils/pr_agent/identity_providers/__init__.py
Normal file
14
apps/utils/pr_agent/identity_providers/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
from utils.pr_agent.identity_providers.default_identity_provider import \
|
||||
DefaultIdentityProvider
|
||||
|
||||
_IDENTITY_PROVIDERS = {
|
||||
'default': DefaultIdentityProvider
|
||||
}
|
||||
|
||||
|
||||
def get_identity_provider():
|
||||
identity_provider_id = get_settings().get("CONFIG.IDENTITY_PROVIDER", "default")
|
||||
if identity_provider_id not in _IDENTITY_PROVIDERS:
|
||||
raise ValueError(f"Unknown identity provider: {identity_provider_id}")
|
||||
return _IDENTITY_PROVIDERS[identity_provider_id]()
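A hypothetical sketch (names invented for illustration) of how an additional identity provider would be wired in:

# _IDENTITY_PROVIDERS['my_org'] = MyOrgIdentityProvider   # a subclass of IdentityProvider
# get_settings().set('CONFIG.IDENTITY_PROVIDER', 'my_org')
# get_identity_provider()  # -> MyOrgIdentityProvider()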
|
||||
10
apps/utils/pr_agent/identity_providers/default_identity_provider.py
Normal file
10
apps/utils/pr_agent/identity_providers/default_identity_provider.py
Normal file
@ -0,0 +1,10 @@
|
||||
from utils.pr_agent.identity_providers.identity_provider import (Eligibility,
|
||||
IdentityProvider)
|
||||
|
||||
|
||||
class DefaultIdentityProvider(IdentityProvider):
|
||||
def verify_eligibility(self, git_provider, git_provider_id, pr_url):
|
||||
return Eligibility.ELIGIBLE
|
||||
|
||||
def inc_invocation_count(self, git_provider, git_provider_id):
|
||||
pass
|
||||
18
apps/utils/pr_agent/identity_providers/identity_provider.py
Normal file
18
apps/utils/pr_agent/identity_providers/identity_provider.py
Normal file
@ -0,0 +1,18 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Eligibility(Enum):
|
||||
NOT_ELIGIBLE = 0
|
||||
ELIGIBLE = 1
|
||||
TRIAL = 2
|
||||
|
||||
|
||||
class IdentityProvider(ABC):
|
||||
@abstractmethod
|
||||
def verify_eligibility(self, git_provider, git_provider_id, pr_url):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def inc_invocation_count(self, git_provider, git_provider_id):
|
||||
pass
|
||||
64
apps/utils/pr_agent/log/__init__.py
Normal file
64
apps/utils/pr_agent/log/__init__.py
Normal file
@ -0,0 +1,64 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from enum import Enum
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
|
||||
|
||||
class LoggingFormat(str, Enum):
|
||||
CONSOLE = "CONSOLE"
|
||||
JSON = "JSON"
|
||||
|
||||
|
||||
def json_format(record: dict) -> str:
|
||||
return record["message"]
|
||||
|
||||
|
||||
def analytics_filter(record: dict) -> bool:
|
||||
return record.get("extra", {}).get("analytics", False)
|
||||
|
||||
|
||||
def inv_analytics_filter(record: dict) -> bool:
|
||||
return not record.get("extra", {}).get("analytics", False)
|
||||
|
||||
|
||||
def setup_logger(level: str = "INFO", fmt: LoggingFormat = LoggingFormat.CONSOLE):
|
||||
level: int = logging.getLevelName(level.upper())
|
||||
if type(level) is not int:
|
||||
level = logging.INFO
|
||||
|
||||
if fmt == LoggingFormat.JSON and os.getenv("LOG_SANE", "0").lower() == "0": # better debugging github_app
|
||||
logger.remove(None)
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
filter=inv_analytics_filter,
|
||||
level=level,
|
||||
format="{message}",
|
||||
colorize=False,
|
||||
serialize=True,
|
||||
)
|
||||
elif fmt == LoggingFormat.CONSOLE: # does not print the 'extra' fields
|
||||
logger.remove(None)
|
||||
logger.add(sys.stdout, level=level, colorize=True, filter=inv_analytics_filter)
|
||||
|
||||
log_folder = get_settings().get("CONFIG.ANALYTICS_FOLDER", "")
|
||||
if log_folder:
|
||||
pid = os.getpid()
|
||||
log_file = os.path.join(log_folder, f"pr-agent.{pid}.log")
|
||||
logger.add(
|
||||
log_file,
|
||||
filter=analytics_filter,
|
||||
level=level,
|
||||
format="{message}",
|
||||
colorize=False,
|
||||
serialize=True,
|
||||
)
|
||||
|
||||
return logger
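Typical wiring, mirroring the call made by the webhook servers later in this commit:

logger = setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
get_logger().info("server starting")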
|
||||
|
||||
|
||||
def get_logger(*args, **kwargs):
|
||||
return logger
|
||||
17
apps/utils/pr_agent/secret_providers/__init__.py
Normal file
17
apps/utils/pr_agent/secret_providers/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
|
||||
|
||||
def get_secret_provider():
|
||||
if not get_settings().get("CONFIG.SECRET_PROVIDER"):
|
||||
return None
|
||||
|
||||
provider_id = get_settings().config.secret_provider
|
||||
if provider_id == 'google_cloud_storage':
|
||||
try:
|
||||
from utils.pr_agent.secret_providers.google_cloud_storage_secret_provider import \
|
||||
GoogleCloudStorageSecretProvider
|
||||
return GoogleCloudStorageSecretProvider()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to initialize google_cloud_storage secret provider {provider_id}") from e
|
||||
else:
|
||||
raise ValueError("Unknown SECRET_PROVIDER")
|
||||
34
apps/utils/pr_agent/secret_providers/google_cloud_storage_secret_provider.py
Normal file
34
apps/utils/pr_agent/secret_providers/google_cloud_storage_secret_provider.py
Normal file
@ -0,0 +1,34 @@
|
||||
import ujson
|
||||
from google.cloud import storage
|
||||
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
from utils.pr_agent.log import get_logger
|
||||
from utils.pr_agent.secret_providers.secret_provider import SecretProvider
|
||||
|
||||
|
||||
class GoogleCloudStorageSecretProvider(SecretProvider):
|
||||
def __init__(self):
|
||||
try:
|
||||
self.client = storage.Client.from_service_account_info(ujson.loads(get_settings().google_cloud_storage.
|
||||
service_account))
|
||||
self.bucket_name = get_settings().google_cloud_storage.bucket_name
|
||||
self.bucket = self.client.bucket(self.bucket_name)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to initialize Google Cloud Storage Secret Provider: {e}")
|
||||
raise e
|
||||
|
||||
def get_secret(self, secret_name: str) -> str:
|
||||
try:
|
||||
blob = self.bucket.blob(secret_name)
|
||||
return blob.download_as_string()
|
||||
except Exception as e:
|
||||
get_logger().warning(f"Failed to get secret {secret_name} from Google Cloud Storage: {e}")
|
||||
return ""
|
||||
|
||||
def store_secret(self, secret_name: str, secret_value: str):
|
||||
try:
|
||||
blob = self.bucket.blob(secret_name)
|
||||
blob.upload_from_string(secret_value)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to store secret {secret_name} in Google Cloud Storage: {e}")
|
||||
raise e
|
||||
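Hypothetical round trip with the provider above, assuming GOOGLE_CLOUD_STORAGE.SERVICE_ACCOUNT and GOOGLE_CLOUD_STORAGE.BUCKET_NAME are already configured; the caller serializes its own payload, as the Bitbucket app below does with JSON:

import json

provider = GoogleCloudStorageSecretProvider()
provider.store_secret("client-key-123", json.dumps({"shared_secret": "s3cr3t"}))
raw = provider.get_secret("client-key-123")  # blob contents, or "" on failure
secrets = json.loads(raw) if raw else {}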
12
apps/utils/pr_agent/secret_providers/secret_provider.py
Normal file
@ -0,0 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class SecretProvider(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_secret(self, secret_name: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def store_secret(self, secret_name: str, secret_value: str):
|
||||
pass
|
||||
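The ABC above only fixes the two-method contract. A minimal in-memory implementation (purely illustrative, not part of this commit) would look like:

class InMemorySecretProvider(SecretProvider):
    def __init__(self):
        self._store: dict[str, str] = {}

    def get_secret(self, secret_name: str) -> str:
        return self._store.get(secret_name, "")

    def store_secret(self, secret_name: str, secret_value: str):
        self._store[secret_name] = secret_value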
0
apps/utils/pr_agent/servers/__init__.py
Normal file
34
apps/utils/pr_agent/servers/atlassian-connect.json
Normal file
@ -0,0 +1,34 @@
|
||||
{
|
||||
"name": "CodiumAI PR-Agent",
|
||||
"description": "CodiumAI PR-Agent",
|
||||
"key": "app_key",
|
||||
"vendor": {
|
||||
"name": "CodiumAI",
|
||||
"url": "https://codium.ai"
|
||||
},
|
||||
"authentication": {
|
||||
"type": "jwt"
|
||||
},
|
||||
"baseUrl": "base_url",
|
||||
"lifecycle": {
|
||||
"installed": "/installed",
|
||||
"uninstalled": "/uninstalled"
|
||||
},
|
||||
"scopes": [
|
||||
"account",
|
||||
"repository:write",
|
||||
"pullrequest:write",
|
||||
"wiki"
|
||||
],
|
||||
"contexts": [
|
||||
"account"
|
||||
],
|
||||
"modules": {
|
||||
"webhooks": [
|
||||
{
|
||||
"event": "*",
|
||||
"url": "/webhook"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
148
apps/utils/pr_agent/servers/azuredevops_server_webhook.py
Normal file
@ -0,0 +1,148 @@
|
||||
# This file contains the code for the Azure DevOps Server webhook server.
|
||||
# The server listens for incoming webhooks from Azure DevOps Server and forwards them to the PR Agent.
|
||||
# ADO webhook documentation: https://learn.microsoft.com/en-us/azure/devops/service-hooks/services/webhooks?view=azure-devops
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
from urllib.parse import unquote
|
||||
|
||||
import uvicorn
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from starlette import status
|
||||
from starlette.background import BackgroundTasks
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette_context.middleware import RawContextMiddleware
|
||||
|
||||
from utils.pr_agent.agent.pr_agent import PRAgent, command2class
|
||||
from utils.pr_agent.algo.utils import update_settings_from_args
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
from utils.pr_agent.git_providers.utils import apply_repo_settings
|
||||
from utils.pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||
|
||||
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
|
||||
security = HTTPBasic()
|
||||
router = APIRouter()
|
||||
available_commands_rgx = re.compile(r"^\/(" + "|".join(command2class.keys()) + r")\s*")
|
||||
azure_devops_server = get_settings().get("azure_devops_server")
|
||||
WEBHOOK_USERNAME = azure_devops_server.get("webhook_username")
|
||||
WEBHOOK_PASSWORD = azure_devops_server.get("webhook_password")
|
||||
|
||||
def handle_request(
|
||||
background_tasks: BackgroundTasks, url: str, body: str, log_context: dict
|
||||
):
|
||||
log_context["action"] = body
|
||||
log_context["api_url"] = url
|
||||
|
||||
async def inner():
|
||||
try:
|
||||
with get_logger().contextualize(**log_context):
|
||||
await PRAgent().handle_request(url, body)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to handle webhook: {e}")
|
||||
|
||||
background_tasks.add_task(inner)
|
||||
|
||||
|
||||
# currently only basic auth is supported with azure webhooks
|
||||
# for this reason, https must be enabled to ensure the credentials are not sent in clear text
|
||||
def authorize(credentials: HTTPBasicCredentials = Depends(security)):
|
||||
is_user_ok = secrets.compare_digest(credentials.username, WEBHOOK_USERNAME)
|
||||
is_pass_ok = secrets.compare_digest(credentials.password, WEBHOOK_PASSWORD)
|
||||
if not (is_user_ok and is_pass_ok):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='Incorrect username or password.',
|
||||
headers={'WWW-Authenticate': 'Basic'},
|
||||
)
|
||||
|
||||
|
||||
async def _perform_commands_azure(commands_conf: str, agent: PRAgent, api_url: str, log_context: dict):
|
||||
apply_repo_settings(api_url)
|
||||
if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
|
||||
get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}", **log_context)
|
||||
return
|
||||
commands = get_settings().get(f"azure_devops_server.{commands_conf}")
|
||||
get_settings().set("config.is_auto_command", True)
|
||||
for command in commands:
|
||||
try:
|
||||
split_command = command.split(" ")
|
||||
command = split_command[0]
|
||||
args = split_command[1:]
|
||||
other_args = update_settings_from_args(args)
|
||||
new_command = ' '.join([command] + other_args)
|
||||
get_logger().info(f"Performing command: {new_command}")
|
||||
with get_logger().contextualize(**log_context):
|
||||
await agent.handle_request(api_url, new_command)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to perform command {command}: {e}")
|
||||
|
||||
|
||||
@router.post("/", dependencies=[Depends(authorize)])
|
||||
async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
|
||||
log_context = {"server_type": "azure_devops_server"}
|
||||
data = await request.json()
|
||||
get_logger().info(json.dumps(data))
|
||||
|
||||
actions = []
|
||||
if data["eventType"] == "git.pullrequest.created":
|
||||
# API V1 (latest)
|
||||
pr_url = unquote(data["resource"]["_links"]["web"]["href"].replace("_apis/git/repositories", "_git"))
|
||||
log_context["event"] = data["eventType"]
|
||||
log_context["api_url"] = pr_url
|
||||
await _perform_commands_azure("pr_commands", PRAgent(), pr_url, log_context)
|
||||
return
|
||||
elif data["eventType"] == "ms.vss-code.git-pullrequest-comment-event" and "content" in data["resource"]["comment"]:
|
||||
if available_commands_rgx.match(data["resource"]["comment"]["content"]):
|
||||
if data["resourceVersion"] == "2.0":
|
||||
repo = data["resource"]["pullRequest"]["repository"]["webUrl"]
|
||||
pr_url = unquote(f'{repo}/pullrequest/{data["resource"]["pullRequest"]["pullRequestId"]}')
|
||||
actions = [data["resource"]["comment"]["content"]]
|
||||
else:
|
||||
# API V1 not supported as it does not contain the PR URL
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content=json.dumps({"message": "version 1.0 webhook for Azure Devops PR comment is not supported. please upgrade to version 2.0"})),
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content=json.dumps({"message": "Unsupported command"}),
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
content=json.dumps({"message": "Unsupported event"}),
|
||||
)
|
||||
|
||||
log_context["event"] = data["eventType"]
|
||||
log_context["api_url"] = pr_url
|
||||
|
||||
for action in actions:
|
||||
try:
|
||||
handle_request(background_tasks, pr_url, action, log_context)
|
||||
except Exception as e:
|
||||
get_logger().error("Azure DevOps Trigger failed. Error:" + str(e))
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content=json.dumps({"message": "Internal server error"}),
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_202_ACCEPTED, content=jsonable_encoder({"message": "webhook triggered successfully"})
|
||||
)
|
||||
|
||||
@router.get("/")
|
||||
async def root():
|
||||
return {"status": "ok"}
|
||||
|
||||
def start():
|
||||
app = FastAPI(middleware=[Middleware(RawContextMiddleware)])
|
||||
app.include_router(router)
|
||||
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "3000")))
|
||||
|
||||
if __name__ == "__main__":
|
||||
start()
|
||||
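An illustrative client call against the Azure DevOps webhook server above; credentials must match azure_devops_server.webhook_username/webhook_password, and since only basic auth is supported the endpoint should be served over HTTPS. The URL and values are placeholders:

import requests

payload = {
    "eventType": "ms.vss-code.git-pullrequest-comment-event",
    "resourceVersion": "2.0",
    "resource": {
        "comment": {"content": "/review"},
        "pullRequest": {
            "repository": {"webUrl": "https://dev.azure.com/org/project/_git/repo"},
            "pullRequestId": 1,
        },
    },
}
resp = requests.post(
    "https://pr-agent.example.com/",        # hypothetical deployment URL
    json=payload,
    auth=("webhook-user", "webhook-pass"),  # basic auth checked by authorize()
)
print(resp.status_code)  # 202 once the command is queued as a background task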
272
apps/utils/pr_agent/servers/bitbucket_app.py
Normal file
@ -0,0 +1,272 @@
|
||||
import base64
|
||||
import copy
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
import uvicorn
|
||||
from fastapi import APIRouter, FastAPI, Request, Response
|
||||
from starlette.background import BackgroundTasks
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette_context import context
|
||||
from starlette_context.middleware import RawContextMiddleware
|
||||
|
||||
from utils.pr_agent.agent.pr_agent import PRAgent
|
||||
from utils.pr_agent.algo.utils import update_settings_from_args
|
||||
from utils.pr_agent.config_loader import get_settings, global_settings
|
||||
from utils.pr_agent.git_providers.utils import apply_repo_settings
|
||||
from utils.pr_agent.identity_providers import get_identity_provider
|
||||
from utils.pr_agent.identity_providers.identity_provider import Eligibility
|
||||
from utils.pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||
from utils.pr_agent.secret_providers import get_secret_provider
|
||||
|
||||
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
|
||||
router = APIRouter()
|
||||
secret_provider = get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
|
||||
|
||||
|
||||
async def get_bearer_token(shared_secret: str, client_key: str):
|
||||
try:
|
||||
now = int(time.time())
|
||||
url = "https://bitbucket.org/site/oauth2/access_token"
|
||||
canonical_url = "GET&/site/oauth2/access_token&"
|
||||
qsh = hashlib.sha256(canonical_url.encode("utf-8")).hexdigest()
|
||||
app_key = get_settings().bitbucket.app_key
|
||||
|
||||
payload = {
|
||||
"iss": app_key,
|
||||
"iat": now,
|
||||
"exp": now + 240,
|
||||
"qsh": qsh,
|
||||
"sub": client_key,
|
||||
}
|
||||
token = jwt.encode(payload, shared_secret, algorithm="HS256")
|
||||
payload = 'grant_type=urn%3Abitbucket%3Aoauth2%3Ajwt'
|
||||
headers = {
|
||||
'Authorization': f'JWT {token}',
|
||||
'Content-Type': 'application/x-www-form-urlencoded'
|
||||
}
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
bearer_token = response.json()["access_token"]
|
||||
return bearer_token
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to get bearer token: {e}")
|
||||
raise e
|
||||
|
||||
@router.get("/")
|
||||
async def handle_manifest(request: Request, response: Response):
|
||||
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
manifest = open(os.path.join(cur_dir, "atlassian-connect.json"), "rt").read()
|
||||
try:
|
||||
manifest = manifest.replace("app_key", get_settings().bitbucket.app_key)
|
||||
manifest = manifest.replace("base_url", get_settings().bitbucket.base_url)
|
||||
except Exception:
|
||||
get_logger().error("Failed to replace api_key in Bitbucket manifest, trying to continue")
|
||||
manifest_obj = json.loads(manifest)
|
||||
return JSONResponse(manifest_obj)
|
||||
|
||||
|
||||
def _get_username(data):
|
||||
actor = data.get("data", {}).get("actor", {})
|
||||
if actor:
|
||||
if "username" in actor:
|
||||
return actor["username"]
|
||||
elif "display_name" in actor:
|
||||
return actor["display_name"]
|
||||
elif "nickname" in actor:
|
||||
return actor["nickname"]
|
||||
return ""
|
||||
|
||||
|
||||
async def _perform_commands_bitbucket(commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict):
|
||||
apply_repo_settings(api_url)
|
||||
if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
|
||||
get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}")
|
||||
return
|
||||
if data.get("event", "") == "pullrequest:created":
|
||||
if not should_process_pr_logic(data):
|
||||
return
|
||||
commands = get_settings().get(f"bitbucket_app.{commands_conf}", {})
|
||||
get_settings().set("config.is_auto_command", True)
|
||||
for command in commands:
|
||||
try:
|
||||
split_command = command.split(" ")
|
||||
command = split_command[0]
|
||||
args = split_command[1:]
|
||||
other_args = update_settings_from_args(args)
|
||||
new_command = ' '.join([command] + other_args)
|
||||
get_logger().info(f"Performing command: {new_command}")
|
||||
with get_logger().contextualize(**log_context):
|
||||
await agent.handle_request(api_url, new_command)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to perform command {command}: {e}")
|
||||
|
||||
|
||||
def is_bot_user(data) -> bool:
|
||||
try:
|
||||
actor = data.get("data", {}).get("actor", {})
|
||||
# allow actor type: user . if it's "AppUser" or "team" then it is a bot user
|
||||
allowed_actor_types = {"user"}
|
||||
if actor and actor["type"].lower() not in allowed_actor_types:
|
||||
get_logger().info(f"BitBucket actor type is not 'user', skipping: {actor}")
|
||||
return True
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed 'is_bot_user' logic: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def should_process_pr_logic(data) -> bool:
|
||||
try:
|
||||
pr_data = data.get("data", {}).get("pullrequest", {})
|
||||
title = pr_data.get("title", "")
|
||||
source_branch = pr_data.get("source", {}).get("branch", {}).get("name", "")
|
||||
target_branch = pr_data.get("destination", {}).get("branch", {}).get("name", "")
|
||||
sender = _get_username(data)
|
||||
|
||||
# logic to ignore PRs from specific users
|
||||
ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
|
||||
if ignore_pr_users and sender:
|
||||
if sender in ignore_pr_users:
|
||||
get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting")
|
||||
return False
|
||||
|
||||
# logic to ignore PRs with specific titles
|
||||
if title:
|
||||
ignore_pr_title_re = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
|
||||
if not isinstance(ignore_pr_title_re, list):
|
||||
ignore_pr_title_re = [ignore_pr_title_re]
|
||||
if ignore_pr_title_re and any(re.search(regex, title) for regex in ignore_pr_title_re):
|
||||
get_logger().info(f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting")
|
||||
return False
|
||||
|
||||
ignore_pr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [])
|
||||
ignore_pr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", [])
|
||||
if (ignore_pr_source_branches or ignore_pr_target_branches):
|
||||
if any(re.search(regex, source_branch) for regex in ignore_pr_source_branches):
|
||||
get_logger().info(
|
||||
f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings")
|
||||
return False
|
||||
if any(re.search(regex, target_branch) for regex in ignore_pr_target_branches):
|
||||
get_logger().info(
|
||||
f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings")
|
||||
return False
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed 'should_process_pr_logic': {e}")
|
||||
return True
|
||||
|
||||
|
||||
@router.post("/webhook")
|
||||
async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Request):
|
||||
app_name = get_settings().get("CONFIG.APP_NAME", "Unknown")
|
||||
log_context = {"server_type": "bitbucket_app", "app_name": app_name}
|
||||
get_logger().debug(request.headers)
|
||||
jwt_header = request.headers.get("authorization", None)
|
||||
if jwt_header:
|
||||
input_jwt = jwt_header.split(" ")[1]
|
||||
data = await request.json()
|
||||
get_logger().debug(data)
|
||||
|
||||
async def inner():
|
||||
try:
|
||||
# ignore bot users
|
||||
if is_bot_user(data):
|
||||
return "OK"
|
||||
|
||||
# Check if the PR should be processed
|
||||
if data.get("event", "") == "pullrequest:created":
|
||||
if not should_process_pr_logic(data):
|
||||
return "OK"
|
||||
|
||||
# Get the username of the sender
|
||||
log_context["sender"] = _get_username(data)
|
||||
|
||||
sender_id = data.get("data", {}).get("actor", {}).get("account_id", "")
|
||||
log_context["sender_id"] = sender_id
|
||||
jwt_parts = input_jwt.split(".")
|
||||
claim_part = jwt_parts[1]
|
||||
claim_part += "=" * (-len(claim_part) % 4)
|
||||
decoded_claims = base64.urlsafe_b64decode(claim_part)
|
||||
claims = json.loads(decoded_claims)
|
||||
client_key = claims["iss"]
|
||||
secrets = json.loads(secret_provider.get_secret(client_key))
|
||||
shared_secret = secrets["shared_secret"]
|
||||
jwt.decode(input_jwt, shared_secret, audience=client_key, algorithms=["HS256"])
|
||||
bearer_token = await get_bearer_token(shared_secret, client_key)
|
||||
context['bitbucket_bearer_token'] = bearer_token
|
||||
context["settings"] = copy.deepcopy(global_settings)
|
||||
event = data["event"]
|
||||
agent = PRAgent()
|
||||
if event == "pullrequest:created":
|
||||
pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
|
||||
log_context["api_url"] = pr_url
|
||||
log_context["event"] = "pull_request"
|
||||
if pr_url:
|
||||
with get_logger().contextualize(**log_context):
|
||||
apply_repo_settings(pr_url)
|
||||
if get_identity_provider().verify_eligibility("bitbucket",
|
||||
sender_id, pr_url) is not Eligibility.NOT_ELIGIBLE:
|
||||
if get_settings().get("bitbucket_app.pr_commands"):
|
||||
await _perform_commands_bitbucket("pr_commands", PRAgent(), pr_url, log_context, data)
|
||||
elif event == "pullrequest:comment_created":
|
||||
pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
|
||||
log_context["api_url"] = pr_url
|
||||
log_context["event"] = "comment"
|
||||
comment_body = data["data"]["comment"]["content"]["raw"]
|
||||
with get_logger().contextualize(**log_context):
|
||||
if get_identity_provider().verify_eligibility("bitbucket",
|
||||
sender_id, pr_url) is not Eligibility.NOT_ELIGIBLE:
|
||||
await agent.handle_request(pr_url, comment_body)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to handle webhook: {e}")
|
||||
background_tasks.add_task(inner)
|
||||
return "OK"
|
||||
|
||||
@router.get("/webhook")
|
||||
async def handle_github_webhooks(request: Request, response: Response):
|
||||
return "Webhook server online!"
|
||||
|
||||
@router.post("/installed")
|
||||
async def handle_installed_webhooks(request: Request, response: Response):
|
||||
try:
|
||||
get_logger().info("handle_installed_webhooks")
|
||||
get_logger().info(request.headers)
|
||||
data = await request.json()
|
||||
get_logger().info(data)
|
||||
shared_secret = data["sharedSecret"]
|
||||
client_key = data["clientKey"]
|
||||
username = data["principal"]["username"]
|
||||
secrets = {
|
||||
"shared_secret": shared_secret,
|
||||
"client_key": client_key
|
||||
}
|
||||
secret_provider.store_secret(username, json.dumps(secrets))
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to register user: {e}")
|
||||
return JSONResponse({"error": "Unable to register user"}, status_code=500)
|
||||
|
||||
@router.post("/uninstalled")
|
||||
async def handle_uninstalled_webhooks(request: Request, response: Response):
|
||||
get_logger().info("handle_uninstalled_webhooks")
|
||||
|
||||
data = await request.json()
|
||||
get_logger().info(data)
|
||||
|
||||
|
||||
def start():
|
||||
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
|
||||
get_settings().set("CONFIG.GIT_PROVIDER", "bitbucket")
|
||||
get_settings().set("PR_DESCRIPTION.PUBLISH_DESCRIPTION_AS_COMMENT", True)
|
||||
middleware = [Middleware(RawContextMiddleware)]
|
||||
app = FastAPI(middleware=middleware)
|
||||
app.include_router(router)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "3000")))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
start()
|
||||
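should_process_pr_logic above reads several optional filters; a sketch of setting them programmatically before start() (setting names come from the code, values are examples):

from utils.pr_agent.config_loader import get_settings

get_settings().set("CONFIG.IGNORE_PR_AUTHORS", ["dependabot[bot]"])
get_settings().set("CONFIG.IGNORE_PR_TITLE", [r"^\[Auto\]", r"^WIP"])
get_settings().set("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [r"^release/"])
get_settings().set("CONFIG.IGNORE_PR_TARGET_BRANCHES", [r"^archive/"])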
164
apps/utils/pr_agent/servers/bitbucket_server_webhook.py
Normal file
@ -0,0 +1,164 @@
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import uvicorn
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import RedirectResponse
|
||||
from starlette import status
|
||||
from starlette.background import BackgroundTasks
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette_context.middleware import RawContextMiddleware
|
||||
|
||||
from utils.pr_agent.agent.pr_agent import PRAgent
|
||||
from utils.pr_agent.algo.utils import update_settings_from_args
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
from utils.pr_agent.git_providers.utils import apply_repo_settings
|
||||
from utils.pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||
from utils.pr_agent.servers.utils import verify_signature
|
||||
|
||||
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def handle_request(
|
||||
background_tasks: BackgroundTasks, url: str, body: str, log_context: dict
|
||||
):
|
||||
log_context["action"] = body
|
||||
log_context["api_url"] = url
|
||||
|
||||
async def inner():
|
||||
try:
|
||||
with get_logger().contextualize(**log_context):
|
||||
await PRAgent().handle_request(url, body)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to handle webhook: {e}")
|
||||
|
||||
background_tasks.add_task(inner)
|
||||
|
||||
@router.post("/")
|
||||
async def redirect_to_webhook():
|
||||
return RedirectResponse(url="/webhook")
|
||||
|
||||
@router.post("/webhook")
|
||||
async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
|
||||
log_context = {"server_type": "bitbucket_server"}
|
||||
data = await request.json()
|
||||
get_logger().info(json.dumps(data))
|
||||
|
||||
webhook_secret = get_settings().get("BITBUCKET_SERVER.WEBHOOK_SECRET", None)
|
||||
if webhook_secret:
|
||||
body_bytes = await request.body()
|
||||
if body_bytes.decode('utf-8') == '{"test": true}':
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "connection test successful"})
|
||||
)
|
||||
signature_header = request.headers.get("x-hub-signature", None)
|
||||
verify_signature(body_bytes, webhook_secret, signature_header)
|
||||
|
||||
pr_id = data["pullRequest"]["id"]
|
||||
repository_name = data["pullRequest"]["toRef"]["repository"]["slug"]
|
||||
project_name = data["pullRequest"]["toRef"]["repository"]["project"]["key"]
|
||||
bitbucket_server = get_settings().get("BITBUCKET_SERVER.URL")
|
||||
pr_url = f"{bitbucket_server}/projects/{project_name}/repos/{repository_name}/pull-requests/{pr_id}"
|
||||
|
||||
log_context["api_url"] = pr_url
|
||||
log_context["event"] = "pull_request"
|
||||
|
||||
commands_to_run = []
|
||||
|
||||
if data["eventKey"] == "pr:opened":
|
||||
apply_repo_settings(pr_url)
|
||||
if get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
|
||||
get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {pr_url}", **log_context)
|
||||
return
|
||||
get_settings().set("config.is_auto_command", True)
|
||||
commands_to_run.extend(_get_commands_list_from_settings('BITBUCKET_SERVER.PR_COMMANDS'))
|
||||
elif data["eventKey"] == "pr:comment:added":
|
||||
commands_to_run.append(data["comment"]["text"])
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content=json.dumps({"message": "Unsupported event"}),
|
||||
)
|
||||
|
||||
async def inner():
|
||||
try:
|
||||
await _run_commands_sequentially(commands_to_run, pr_url, log_context)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to handle webhook: {e}")
|
||||
|
||||
background_tasks.add_task(inner)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"})
|
||||
)
|
||||
|
||||
|
||||
async def _run_commands_sequentially(commands: List[str], url: str, log_context: dict):
|
||||
get_logger().info(f"Running commands sequentially: {commands}")
|
||||
if commands is None:
|
||||
return
|
||||
|
||||
for command in commands:
|
||||
try:
|
||||
body = _process_command(command, url)
|
||||
|
||||
log_context["action"] = body
|
||||
log_context["api_url"] = url
|
||||
|
||||
with get_logger().contextualize(**log_context):
|
||||
await PRAgent().handle_request(url, body)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to handle command: {command} , error: {e}")
|
||||
|
||||
def _process_command(command: str, url) -> str:
|
||||
# don't think we need this
|
||||
apply_repo_settings(url)
|
||||
# Process the command string
|
||||
split_command = command.split(" ")
|
||||
command = split_command[0]
|
||||
args = split_command[1:]
|
||||
# do I need this? if yes, shouldn't this be done in PRAgent?
|
||||
other_args = update_settings_from_args(args)
|
||||
new_command = ' '.join([command] + other_args)
|
||||
return new_command
|
||||
|
||||
|
||||
def _to_list(command_string: str) -> list:
|
||||
try:
|
||||
# Use ast.literal_eval to safely parse the string into a list
|
||||
commands = ast.literal_eval(command_string)
|
||||
# Check if the parsed object is a list of strings
|
||||
if isinstance(commands, list) and all(isinstance(cmd, str) for cmd in commands):
|
||||
return commands
|
||||
else:
|
||||
raise ValueError("Parsed data is not a list of strings.")
|
||||
except (SyntaxError, ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid command string: {e}")
|
||||
|
||||
|
||||
def _get_commands_list_from_settings(setting_key: str) -> list:
|
||||
try:
|
||||
return get_settings().get(setting_key, [])
|
||||
except ValueError as e:
|
||||
get_logger().error(f"Failed to get commands list from settings {setting_key}: {e}")
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def root():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
def start():
|
||||
app = FastAPI(middleware=[Middleware(RawContextMiddleware)])
|
||||
app.include_router(router)
|
||||
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "3000")))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
start()
|
||||
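For the Bitbucket Server handler above, each configured command string is normalized by _process_command before it reaches PRAgent; an illustrative call (the PR URL is hypothetical, and apply_repo_settings(url) inside would contact the configured server):

pr_url = "https://bitbucket.internal/projects/PROJ/repos/repo/pull-requests/7"
body = _process_command("/describe --pr_description.publish_labels=false", pr_url)
# the --key=value argument is applied to the settings by update_settings_from_args
# and stripped, so body normally ends up as just "/describe"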
77
apps/utils/pr_agent/servers/gerrit_server.py
Normal file
@ -0,0 +1,77 @@
|
||||
import copy
|
||||
from enum import Enum
|
||||
from json import JSONDecodeError
|
||||
|
||||
import uvicorn
|
||||
from fastapi import APIRouter, FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from starlette.middleware import Middleware
|
||||
from starlette_context import context
|
||||
from starlette_context.middleware import RawContextMiddleware
|
||||
|
||||
from utils.pr_agent.agent.pr_agent import PRAgent
|
||||
from utils.pr_agent.config_loader import get_settings, global_settings
|
||||
from utils.pr_agent.log import get_logger, setup_logger
|
||||
|
||||
setup_logger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class Action(str, Enum):
|
||||
review = "review"
|
||||
describe = "describe"
|
||||
ask = "ask"
|
||||
improve = "improve"
|
||||
reflect = "reflect"
|
||||
answer = "answer"
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
refspec: str
|
||||
project: str
|
||||
msg: str
|
||||
|
||||
|
||||
@router.post("/api/v1/gerrit/{action}")
|
||||
async def handle_gerrit_request(action: Action, item: Item):
|
||||
get_logger().debug("Received a Gerrit request")
|
||||
context["settings"] = copy.deepcopy(global_settings)
|
||||
|
||||
if action == Action.ask:
|
||||
if not item.msg:
|
||||
return HTTPException(
|
||||
status_code=400,
|
||||
detail="msg is required for ask command"
|
||||
)
|
||||
await PRAgent().handle_request(
|
||||
f"{item.project}:{item.refspec}",
|
||||
f"/{item.msg.strip()}"
|
||||
)
|
||||
|
||||
|
||||
async def get_body(request):
|
||||
try:
|
||||
body = await request.json()
|
||||
except JSONDecodeError as e:
|
||||
get_logger().error("Error parsing request body", e)
|
||||
return {}
|
||||
return body
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def root():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
def start():
|
||||
# to prevent adding help messages with the output
|
||||
get_settings().set("CONFIG.CLI_MODE", True)
|
||||
middleware = [Middleware(RawContextMiddleware)]
|
||||
app = FastAPI(middleware=middleware)
|
||||
app.include_router(router)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=3000)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
start()
|
||||
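A hypothetical client call to the Gerrit endpoint above, matching the Item model (project, refspec, msg); the action is part of the path and msg carries the slash command the server forwards to PRAgent:

import requests

resp = requests.post(
    "http://localhost:3000/api/v1/gerrit/ask",
    json={
        "project": "my/project",
        "refspec": "refs/changes/42/1242/1",
        "msg": "ask what does this change do?",  # becomes "/ask what does this change do?"
    },
)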
160
apps/utils/pr_agent/servers/github_action_runner.py
Normal file
@ -0,0 +1,160 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from utils.pr_agent.agent.pr_agent import PRAgent
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
from utils.pr_agent.git_providers import get_git_provider
|
||||
from utils.pr_agent.git_providers.utils import apply_repo_settings
|
||||
from utils.pr_agent.log import get_logger
|
||||
from utils.pr_agent.servers.github_app import handle_line_comments
|
||||
from utils.pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
|
||||
from utils.pr_agent.tools.pr_description import PRDescription
|
||||
from utils.pr_agent.tools.pr_reviewer import PRReviewer
|
||||
|
||||
|
||||
def is_true(value: Union[str, bool]) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.lower() == 'true'
|
||||
return False
|
||||
|
||||
|
||||
def get_setting_or_env(key: str, default: Union[str, bool] = None) -> Union[str, bool]:
|
||||
try:
|
||||
value = get_settings().get(key, default)
|
||||
except AttributeError: # TBD still need to debug why this happens on GitHub Actions
|
||||
value = os.getenv(key, None) or os.getenv(key.upper(), None) or os.getenv(key.lower(), None) or default
|
||||
return value
|
||||
|
||||
|
||||
async def run_action():
|
||||
# Get environment variables
|
||||
GITHUB_EVENT_NAME = os.environ.get('GITHUB_EVENT_NAME')
|
||||
GITHUB_EVENT_PATH = os.environ.get('GITHUB_EVENT_PATH')
|
||||
OPENAI_KEY = os.environ.get('OPENAI_KEY') or os.environ.get('OPENAI.KEY')
|
||||
OPENAI_ORG = os.environ.get('OPENAI_ORG') or os.environ.get('OPENAI.ORG')
|
||||
GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN')
|
||||
# get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
|
||||
|
||||
# Check if required environment variables are set
|
||||
if not GITHUB_EVENT_NAME:
|
||||
print("GITHUB_EVENT_NAME not set")
|
||||
return
|
||||
if not GITHUB_EVENT_PATH:
|
||||
print("GITHUB_EVENT_PATH not set")
|
||||
return
|
||||
if not GITHUB_TOKEN:
|
||||
print("GITHUB_TOKEN not set")
|
||||
return
|
||||
|
||||
# Set the environment variables in the settings
|
||||
if OPENAI_KEY:
|
||||
get_settings().set("OPENAI.KEY", OPENAI_KEY)
|
||||
else:
|
||||
# Might not be set if the user is using models not from OpenAI
|
||||
print("OPENAI_KEY not set")
|
||||
if OPENAI_ORG:
|
||||
get_settings().set("OPENAI.ORG", OPENAI_ORG)
|
||||
get_settings().set("GITHUB.USER_TOKEN", GITHUB_TOKEN)
|
||||
get_settings().set("GITHUB.DEPLOYMENT_TYPE", "user")
|
||||
enable_output = get_setting_or_env("GITHUB_ACTION_CONFIG.ENABLE_OUTPUT", True)
|
||||
get_settings().set("GITHUB_ACTION_CONFIG.ENABLE_OUTPUT", enable_output)
|
||||
|
||||
# Load the event payload
|
||||
try:
|
||||
with open(GITHUB_EVENT_PATH, 'r') as f:
|
||||
event_payload = json.load(f)
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
print(f"Failed to parse JSON: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
get_logger().info("Applying repo settings")
|
||||
pr_url = event_payload.get("pull_request", {}).get("html_url")
|
||||
if pr_url:
|
||||
apply_repo_settings(pr_url)
|
||||
get_logger().info(f"enable_custom_labels: {get_settings().config.enable_custom_labels}")
|
||||
except Exception as e:
|
||||
get_logger().info(f"github action: failed to apply repo settings: {e}")
|
||||
|
||||
# Handle pull request opened event
|
||||
if GITHUB_EVENT_NAME == "pull_request" or GITHUB_EVENT_NAME == "pull_request_target":
|
||||
action = event_payload.get("action")
|
||||
|
||||
# Retrieve the list of actions from the configuration
|
||||
pr_actions = get_settings().get("GITHUB_ACTION_CONFIG.PR_ACTIONS", ["opened", "reopened", "ready_for_review", "review_requested"])
|
||||
|
||||
if action in pr_actions:
|
||||
pr_url = event_payload.get("pull_request", {}).get("url")
|
||||
if pr_url:
|
||||
# legacy - supporting both GITHUB_ACTION and GITHUB_ACTION_CONFIG
|
||||
auto_review = get_setting_or_env("GITHUB_ACTION.AUTO_REVIEW", None)
|
||||
if auto_review is None:
|
||||
auto_review = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_REVIEW", None)
|
||||
auto_describe = get_setting_or_env("GITHUB_ACTION.AUTO_DESCRIBE", None)
|
||||
if auto_describe is None:
|
||||
auto_describe = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_DESCRIBE", None)
|
||||
auto_improve = get_setting_or_env("GITHUB_ACTION.AUTO_IMPROVE", None)
|
||||
if auto_improve is None:
|
||||
auto_improve = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_IMPROVE", None)
|
||||
|
||||
# Set the configuration for auto actions
|
||||
get_settings().config.is_auto_command = True # Set the flag to indicate that the command is auto
|
||||
get_settings().pr_description.final_update_message = False # No final update message when auto_describe is enabled
|
||||
get_logger().info(f"Running auto actions: auto_describe={auto_describe}, auto_review={auto_review}, auto_improve={auto_improve}")
|
||||
|
||||
# invoke by default all three tools
|
||||
if auto_describe is None or is_true(auto_describe):
|
||||
await PRDescription(pr_url).run()
|
||||
if auto_review is None or is_true(auto_review):
|
||||
await PRReviewer(pr_url).run()
|
||||
if auto_improve is None or is_true(auto_improve):
|
||||
await PRCodeSuggestions(pr_url).run()
|
||||
else:
|
||||
get_logger().info(f"Skipping action: {action}")
|
||||
|
||||
# Handle issue comment event
|
||||
elif GITHUB_EVENT_NAME == "issue_comment" or GITHUB_EVENT_NAME == "pull_request_review_comment":
|
||||
action = event_payload.get("action")
|
||||
if action in ["created", "edited"]:
|
||||
comment_body = event_payload.get("comment", {}).get("body")
|
||||
try:
|
||||
if GITHUB_EVENT_NAME == "pull_request_review_comment":
|
||||
if '/ask' in comment_body:
|
||||
comment_body = handle_line_comments(event_payload, comment_body)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to handle line comments: {e}")
|
||||
return
|
||||
if comment_body:
|
||||
is_pr = False
|
||||
disable_eyes = False
|
||||
# check if issue is pull request
|
||||
if event_payload.get("issue", {}).get("pull_request"):
|
||||
url = event_payload.get("issue", {}).get("pull_request", {}).get("url")
|
||||
is_pr = True
|
||||
elif event_payload.get("comment", {}).get("pull_request_url"): # for 'pull_request_review_comment
|
||||
url = event_payload.get("comment", {}).get("pull_request_url")
|
||||
is_pr = True
|
||||
disable_eyes = True
|
||||
else:
|
||||
url = event_payload.get("issue", {}).get("url")
|
||||
|
||||
if url:
|
||||
body = comment_body.strip().lower()
|
||||
comment_id = event_payload.get("comment", {}).get("id")
|
||||
provider = get_git_provider()(pr_url=url)
|
||||
if is_pr:
|
||||
await PRAgent().handle_request(
|
||||
url, body, notify=lambda: provider.add_eyes_reaction(
|
||||
comment_id, disable_eyes=disable_eyes
|
||||
)
|
||||
)
|
||||
else:
|
||||
await PRAgent().handle_request(url, body)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(run_action())
|
||||
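run_action above is driven entirely by environment variables; an illustrative local invocation (paths and tokens are placeholders):

import asyncio
import os

os.environ["GITHUB_EVENT_NAME"] = "pull_request"
os.environ["GITHUB_EVENT_PATH"] = "/tmp/event.json"  # JSON payload containing a "pull_request" object
os.environ["GITHUB_TOKEN"] = "<github token>"
os.environ["OPENAI_KEY"] = "<openai key>"            # optional when using non-OpenAI models

asyncio.run(run_action())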
424
apps/utils/pr_agent/servers/github_app.py
Normal file
@ -0,0 +1,424 @@
|
||||
import asyncio.locks
|
||||
import copy
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import uvicorn
|
||||
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
|
||||
from starlette.background import BackgroundTasks
|
||||
from starlette.middleware import Middleware
|
||||
from starlette_context import context
|
||||
from starlette_context.middleware import RawContextMiddleware
|
||||
|
||||
from utils.pr_agent.agent.pr_agent import PRAgent
|
||||
from utils.pr_agent.algo.utils import update_settings_from_args
|
||||
from utils.pr_agent.config_loader import get_settings, global_settings
|
||||
from utils.pr_agent.git_providers import (get_git_provider,
|
||||
get_git_provider_with_context)
|
||||
from utils.pr_agent.git_providers.utils import apply_repo_settings
|
||||
from utils.pr_agent.identity_providers import get_identity_provider
|
||||
from utils.pr_agent.identity_providers.identity_provider import Eligibility
|
||||
from utils.pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||
from utils.pr_agent.servers.utils import DefaultDictWithTimeout, verify_signature
|
||||
|
||||
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
|
||||
base_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
build_number_path = os.path.join(base_path, "build_number.txt")
|
||||
if os.path.exists(build_number_path):
|
||||
with open(build_number_path) as f:
|
||||
build_number = f.read().strip()
|
||||
else:
|
||||
build_number = "unknown"
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/api/v1/github_webhooks")
|
||||
async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Request, response: Response):
|
||||
"""
|
||||
Receives and processes incoming GitHub webhook requests.
|
||||
Verifies the request signature, parses the request body, and passes it to the handle_request function for further
|
||||
processing.
|
||||
"""
|
||||
get_logger().debug("Received a GitHub webhook")
|
||||
|
||||
body = await get_body(request)
|
||||
|
||||
installation_id = body.get("installation", {}).get("id")
|
||||
context["installation_id"] = installation_id
|
||||
context["settings"] = copy.deepcopy(global_settings)
|
||||
context["git_provider"] = {}
|
||||
background_tasks.add_task(handle_request, body, event=request.headers.get("X-GitHub-Event", None))
|
||||
return {}
|
||||
|
||||
|
||||
@router.post("/api/v1/marketplace_webhooks")
|
||||
async def handle_marketplace_webhooks(request: Request, response: Response):
|
||||
body = await get_body(request)
|
||||
get_logger().info(f'Request body:\n{body}')
|
||||
|
||||
|
||||
async def get_body(request):
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception as e:
|
||||
get_logger().error("Error parsing request body", e)
|
||||
raise HTTPException(status_code=400, detail="Error parsing request body") from e
|
||||
webhook_secret = getattr(get_settings().github, 'webhook_secret', None)
|
||||
if webhook_secret:
|
||||
body_bytes = await request.body()
|
||||
signature_header = request.headers.get('x-hub-signature-256', None)
|
||||
verify_signature(body_bytes, webhook_secret, signature_header)
|
||||
return body
|
||||
|
||||
|
||||
_duplicate_push_triggers = DefaultDictWithTimeout(ttl=get_settings().github_app.push_trigger_pending_tasks_ttl)
|
||||
_pending_task_duplicate_push_conditions = DefaultDictWithTimeout(asyncio.locks.Condition, ttl=get_settings().github_app.push_trigger_pending_tasks_ttl)
|
||||
|
||||
async def handle_comments_on_pr(body: Dict[str, Any],
|
||||
event: str,
|
||||
sender: str,
|
||||
sender_id: str,
|
||||
action: str,
|
||||
log_context: Dict[str, Any],
|
||||
agent: PRAgent):
|
||||
if "comment" not in body:
|
||||
return {}
|
||||
comment_body = body.get("comment", {}).get("body")
|
||||
if comment_body and isinstance(comment_body, str) and not comment_body.lstrip().startswith("/"):
|
||||
if '/ask' in comment_body and comment_body.strip().startswith('> ![image]'):
|
||||
comment_body_split = comment_body.split('/ask')
|
||||
comment_body = '/ask' + comment_body_split[1] +' \n' +comment_body_split[0].strip().lstrip('>')
|
||||
get_logger().info(f"Reformatting comment_body so command is at the beginning: {comment_body}")
|
||||
else:
|
||||
get_logger().info("Ignoring comment not starting with /")
|
||||
return {}
|
||||
disable_eyes = False
|
||||
if "issue" in body and "pull_request" in body["issue"] and "url" in body["issue"]["pull_request"]:
|
||||
api_url = body["issue"]["pull_request"]["url"]
|
||||
elif "comment" in body and "pull_request_url" in body["comment"]:
|
||||
api_url = body["comment"]["pull_request_url"]
|
||||
try:
|
||||
if ('/ask' in comment_body and
|
||||
'subject_type' in body["comment"] and body["comment"]["subject_type"] == "line"):
|
||||
# comment on a code line in the "files changed" tab
|
||||
comment_body = handle_line_comments(body, comment_body)
|
||||
disable_eyes = True
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to handle line comments: {e}")
|
||||
else:
|
||||
return {}
|
||||
log_context["api_url"] = api_url
|
||||
comment_id = body.get("comment", {}).get("id")
|
||||
provider = get_git_provider_with_context(pr_url=api_url)
|
||||
with get_logger().contextualize(**log_context):
|
||||
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
|
||||
get_logger().info(f"Processing comment on PR {api_url=}, comment_body={comment_body}")
|
||||
await agent.handle_request(api_url, comment_body,
|
||||
notify=lambda: provider.add_eyes_reaction(comment_id, disable_eyes=disable_eyes))
|
||||
else:
|
||||
get_logger().info(f"User {sender=} is not eligible to process comment on PR {api_url=}")
|
||||
|
||||
async def handle_new_pr_opened(body: Dict[str, Any],
|
||||
event: str,
|
||||
sender: str,
|
||||
sender_id: str,
|
||||
action: str,
|
||||
log_context: Dict[str, Any],
|
||||
agent: PRAgent):
|
||||
title = body.get("pull_request", {}).get("title", "")
|
||||
|
||||
pull_request, api_url = _check_pull_request_event(action, body, log_context)
|
||||
if not (pull_request and api_url):
|
||||
get_logger().info(f"Invalid PR event: {action=} {api_url=}")
|
||||
return {}
|
||||
if action in get_settings().github_app.handle_pr_actions: # ['opened', 'reopened', 'ready_for_review']
|
||||
# logic to ignore PRs with specific titles (e.g. "[Auto] ...")
|
||||
apply_repo_settings(api_url)
|
||||
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
|
||||
await _perform_auto_commands_github("pr_commands", agent, body, api_url, log_context)
|
||||
else:
|
||||
get_logger().info(f"User {sender=} is not eligible to process PR {api_url=}")
|
||||
|
||||
async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
|
||||
event: str,
|
||||
sender: str,
|
||||
sender_id: str,
|
||||
action: str,
|
||||
log_context: Dict[str, Any],
|
||||
agent: PRAgent):
|
||||
pull_request, api_url = _check_pull_request_event(action, body, log_context)
|
||||
if not (pull_request and api_url):
|
||||
return {}
|
||||
|
||||
apply_repo_settings(api_url) # we need to apply the repo settings to get the correct settings for the PR. This is quite expensive - a call to the git provider is made for each PR event.
|
||||
if not get_settings().github_app.handle_push_trigger:
|
||||
return {}
|
||||
|
||||
# TODO: do we still want to get the list of commits to filter bot/merge commits?
|
||||
before_sha = body.get("before")
|
||||
after_sha = body.get("after")
|
||||
merge_commit_sha = pull_request.get("merge_commit_sha")
|
||||
if before_sha == after_sha:
|
||||
return {}
|
||||
if get_settings().github_app.push_trigger_ignore_merge_commits and after_sha == merge_commit_sha:
|
||||
return {}
|
||||
|
||||
# Prevent triggering multiple times for subsequent push triggers when one is enough:
|
||||
# The first push will trigger the processing, and if there's a second push in the meanwhile it will wait.
|
||||
# Any more events will be discarded, because they will all trigger the exact same processing on the PR.
|
||||
# We let the second event wait instead of discarding it because while the first event was being processed,
|
||||
# more commits may have been pushed that led to the subsequent events,
|
||||
# so we keep just one waiting as a delegate to trigger the processing for the new commits when done waiting.
|
||||
current_active_tasks = _duplicate_push_triggers.setdefault(api_url, 0)
|
||||
max_active_tasks = 2 if get_settings().github_app.push_trigger_pending_tasks_backlog else 1
|
||||
if current_active_tasks < max_active_tasks:
|
||||
# first task can enter, and second tasks too if backlog is enabled
|
||||
get_logger().info(
|
||||
f"Continue processing push trigger for {api_url=} because there are {current_active_tasks} active tasks"
|
||||
)
|
||||
_duplicate_push_triggers[api_url] += 1
|
||||
else:
|
||||
get_logger().info(
|
||||
f"Skipping push trigger for {api_url=} because another event already triggered the same processing"
|
||||
)
|
||||
return {}
|
||||
async with _pending_task_duplicate_push_conditions[api_url]:
|
||||
if current_active_tasks == 1:
|
||||
# second task waits
|
||||
get_logger().info(
|
||||
f"Waiting to process push trigger for {api_url=} because the first task is still in progress"
|
||||
)
|
||||
await _pending_task_duplicate_push_conditions[api_url].wait()
|
||||
get_logger().info(f"Finished waiting to process push trigger for {api_url=} - continue with flow")
|
||||
|
||||
try:
|
||||
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
|
||||
get_logger().info(f"Performing incremental review for {api_url=} because of {event=} and {action=}")
|
||||
await _perform_auto_commands_github("push_commands", agent, body, api_url, log_context)
|
||||
|
||||
finally:
|
||||
# release the waiting task block
|
||||
async with _pending_task_duplicate_push_conditions[api_url]:
|
||||
_pending_task_duplicate_push_conditions[api_url].notify(1)
|
||||
_duplicate_push_triggers[api_url] -= 1
|
||||
|
||||
|
||||
def handle_closed_pr(body, event, action, log_context):
|
||||
pull_request = body.get("pull_request", {})
|
||||
is_merged = pull_request.get("merged", False)
|
||||
if not is_merged:
|
||||
return
|
||||
api_url = pull_request.get("url", "")
|
||||
pr_statistics = get_git_provider()(pr_url=api_url).calc_pr_statistics(pull_request)
|
||||
log_context["api_url"] = api_url
|
||||
get_logger().info("PR-Agent statistics for closed PR", analytics=True, pr_statistics=pr_statistics, **log_context)
|
||||
|
||||
|
||||
def get_log_context(body, event, action, build_number):
|
||||
sender = ""
|
||||
sender_id = ""
|
||||
sender_type = ""
|
||||
try:
|
||||
sender = body.get("sender", {}).get("login")
|
||||
sender_id = body.get("sender", {}).get("id")
|
||||
sender_type = body.get("sender", {}).get("type")
|
||||
repo = body.get("repository", {}).get("full_name", "")
|
||||
git_org = body.get("organization", {}).get("login", "")
|
||||
installation_id = body.get("installation", {}).get("id", "")
|
||||
app_name = get_settings().get("CONFIG.APP_NAME", "Unknown")
|
||||
log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app",
|
||||
"request_id": uuid.uuid4().hex, "build_number": build_number, "app_name": app_name,
|
||||
"repo": repo, "git_org": git_org, "installation_id": installation_id}
|
||||
except Exception as e:
|
||||
get_logger().error("Failed to get log context", e)
|
||||
log_context = {}
|
||||
return log_context, sender, sender_id, sender_type
|
||||
|
||||
|
||||
def is_bot_user(sender, sender_type):
|
||||
try:
|
||||
# logic to ignore PRs opened by bot
|
||||
if get_settings().get("GITHUB_APP.IGNORE_BOT_PR", False) and sender_type == "Bot":
|
||||
if 'pr-agent' not in sender:
|
||||
get_logger().info(f"Ignoring PR from '{sender=}' because it is a bot")
|
||||
return True
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed 'is_bot_user' logic: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def should_process_pr_logic(body) -> bool:
|
||||
try:
|
||||
pull_request = body.get("pull_request", {})
|
||||
title = pull_request.get("title", "")
|
||||
pr_labels = pull_request.get("labels", [])
|
||||
source_branch = pull_request.get("head", {}).get("ref", "")
|
||||
target_branch = pull_request.get("base", {}).get("ref", "")
|
||||
sender = body.get("sender", {}).get("login")
|
||||
|
||||
# logic to ignore PRs from specific users
|
||||
ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
|
||||
if ignore_pr_users and sender:
|
||||
if sender in ignore_pr_users:
|
||||
get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting")
|
||||
return False
|
||||
|
||||
# logic to ignore PRs with specific titles
|
||||
if title:
|
||||
ignore_pr_title_re = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
|
||||
if not isinstance(ignore_pr_title_re, list):
|
||||
ignore_pr_title_re = [ignore_pr_title_re]
|
||||
if ignore_pr_title_re and any(re.search(regex, title) for regex in ignore_pr_title_re):
|
||||
get_logger().info(f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting")
|
||||
return False
|
||||
|
||||
# logic to ignore PRs with specific labels or source branches or target branches.
|
||||
ignore_pr_labels = get_settings().get("CONFIG.IGNORE_PR_LABELS", [])
|
||||
if pr_labels and ignore_pr_labels:
|
||||
labels = [label['name'] for label in pr_labels]
|
||||
if any(label in ignore_pr_labels for label in labels):
|
||||
labels_str = ", ".join(labels)
|
||||
get_logger().info(f"Ignoring PR with labels '{labels_str}' due to config.ignore_pr_labels settings")
|
||||
return False
|
||||
|
||||
# logic to ignore PRs with specific source or target branches
|
||||
ignore_pr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [])
|
||||
ignore_pr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", [])
|
||||
if pull_request and (ignore_pr_source_branches or ignore_pr_target_branches):
|
||||
if any(re.search(regex, source_branch) for regex in ignore_pr_source_branches):
|
||||
get_logger().info(
|
||||
f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings")
|
||||
return False
|
||||
if any(re.search(regex, target_branch) for regex in ignore_pr_target_branches):
|
||||
get_logger().info(
|
||||
f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings")
|
||||
return False
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed 'should_process_pr_logic': {e}")
|
||||
return True
|
||||
|
||||
|
||||
async def handle_request(body: Dict[str, Any], event: str):
|
||||
"""
|
||||
Handle incoming GitHub webhook requests.
|
||||
|
||||
Args:
|
||||
body: The request body.
|
||||
event: The GitHub event type (e.g. "pull_request", "issue_comment", etc.).
|
||||
"""
|
||||
action = body.get("action") # "created", "opened", "reopened", "ready_for_review", "review_requested", "synchronize"
|
||||
if not action:
|
||||
return {}
|
||||
agent = PRAgent()
|
||||
log_context, sender, sender_id, sender_type = get_log_context(body, event, action, build_number)
|
||||
|
||||
# logic to ignore PRs opened by bot, PRs with specific titles, labels, source branches, or target branches
|
||||
if is_bot_user(sender, sender_type) and 'check_run' not in body:
|
||||
return {}
|
||||
if action != 'created' and 'check_run' not in body:
|
||||
if not should_process_pr_logic(body):
|
||||
return {}
|
||||
|
||||
if 'check_run' in body: # handle failed checks
|
||||
# get_logger().debug(f'Request body', artifact=body, event=event) # added inside handle_checks
|
||||
pass
|
||||
# handle comments on PRs
|
||||
elif action == 'created':
|
||||
get_logger().debug(f'Request body', artifact=body, event=event)
|
||||
await handle_comments_on_pr(body, event, sender, sender_id, action, log_context, agent)
|
||||
# handle new PRs
|
||||
elif event == 'pull_request' and action != 'synchronize' and action != 'closed':
|
||||
get_logger().debug(f'Request body', artifact=body, event=event)
|
||||
await handle_new_pr_opened(body, event, sender, sender_id, action, log_context, agent)
|
||||
elif event == "issue_comment" and 'edited' in action:
|
||||
pass # handle_checkbox_clicked
|
||||
# handle pull_request event with synchronize action - "push trigger" for new commits
|
||||
elif event == 'pull_request' and action == 'synchronize':
|
||||
await handle_push_trigger_for_new_commits(body, event, sender, sender_id, action, log_context, agent)
|
||||
elif event == 'pull_request' and action == 'closed':
|
||||
if get_settings().get("CONFIG.ANALYTICS_FOLDER", ""):
|
||||
handle_closed_pr(body, event, action, log_context)
|
||||
else:
|
||||
get_logger().info(f"event {event=} action {action=} does not require any handling")
|
||||
return {}
|
||||
|
||||
|
||||
def handle_line_comments(body: Dict, comment_body: str) -> str:
|
||||
if not comment_body:
|
||||
return ""
|
||||
start_line = body["comment"]["start_line"]
|
||||
end_line = body["comment"]["line"]
|
||||
start_line = end_line if not start_line else start_line
|
||||
question = comment_body.replace('/ask', '').strip()
|
||||
diff_hunk = body["comment"]["diff_hunk"]
|
||||
get_settings().set("ask_diff_hunk", diff_hunk)
|
||||
path = body["comment"]["path"]
|
||||
side = body["comment"]["side"]
|
||||
comment_id = body["comment"]["id"]
|
||||
if '/ask' in comment_body:
|
||||
comment_body = f"/ask_line --line_start={start_line} --line_end={end_line} --side={side} --file_name={path} --comment_id={comment_id} {question}"
|
||||
return comment_body
|
||||
|
||||
|
||||
def _check_pull_request_event(action: str, body: dict, log_context: dict) -> Tuple[Dict[str, Any], str]:
|
||||
invalid_result = {}, ""
|
||||
pull_request = body.get("pull_request")
|
||||
if not pull_request:
|
||||
return invalid_result
|
||||
api_url = pull_request.get("url")
|
||||
if not api_url:
|
||||
return invalid_result
|
||||
log_context["api_url"] = api_url
|
||||
if pull_request.get("draft", True) or pull_request.get("state") != "open":
|
||||
return invalid_result
|
||||
if action in ("review_requested", "synchronize") and pull_request.get("created_at") == pull_request.get("updated_at"):
|
||||
# avoid double reviews when opening a PR for the first time
|
||||
return invalid_result
|
||||
return pull_request, api_url
|
||||
|
||||
|
||||
async def _perform_auto_commands_github(commands_conf: str, agent: PRAgent, body: dict, api_url: str,
|
||||
log_context: dict):
|
||||
apply_repo_settings(api_url)
|
||||
if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
|
||||
get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}")
|
||||
return
|
||||
if not should_process_pr_logic(body): # Here we already updated the configuration with the repo settings
|
||||
return {}
|
||||
commands = get_settings().get(f"github_app.{commands_conf}")
|
||||
if not commands:
|
||||
get_logger().info(f"New PR, but no auto commands configured")
|
||||
return
|
||||
get_settings().set("config.is_auto_command", True)
|
||||
for command in commands:
|
||||
split_command = command.split(" ")
|
||||
command = split_command[0]
|
||||
args = split_command[1:]
|
||||
other_args = update_settings_from_args(args)
|
||||
new_command = ' '.join([command] + other_args)
|
||||
get_logger().info(f"{commands_conf}. Performing auto command '{new_command}', for {api_url=}")
|
||||
await agent.handle_request(api_url, new_command)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def root():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
if get_settings().github_app.override_deployment_type:
|
||||
# Override the deployment type to app
|
||||
get_settings().set("GITHUB.DEPLOYMENT_TYPE", "app")
|
||||
# get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
|
||||
middleware = [Middleware(RawContextMiddleware)]
|
||||
app = FastAPI(middleware=middleware)
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
def start():
|
||||
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "3000")))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
start()
|
||||
241
apps/utils/pr_agent/servers/github_polling.py
Normal file
241
apps/utils/pr_agent/servers/github_polling.py
Normal file
@ -0,0 +1,241 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import traceback
|
||||
from collections import deque
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from utils.pr_agent.agent.pr_agent import PRAgent
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
from utils.pr_agent.git_providers import get_git_provider
|
||||
from utils.pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||
|
||||
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
|
||||
NOTIFICATION_URL = "https://api.github.com/notifications"
|
||||
|
||||
|
||||
async def mark_notification_as_read(headers, notification, session):
|
||||
async with session.patch(
|
||||
f"https://api.github.com/notifications/threads/{notification['id']}",
|
||||
headers=headers) as mark_read_response:
|
||||
if mark_read_response.status != 205:
|
||||
get_logger().error(
|
||||
f"Failed to mark notification as read. Status code: {mark_read_response.status}")
|
||||
|
||||
|
||||
def now() -> str:
|
||||
"""
|
||||
Get the current UTC time in ISO 8601 format.
|
||||
|
||||
Returns:
|
||||
str: The current UTC time in ISO 8601 format.
|
||||
"""
|
||||
now_utc = datetime.now(timezone.utc).isoformat()
|
||||
now_utc = now_utc.replace("+00:00", "Z")
|
||||
return now_utc
|
||||
|
||||
async def async_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
|
||||
agent = PRAgent()
|
||||
success = await agent.handle_request(
|
||||
pr_url,
|
||||
rest_of_comment,
|
||||
notify=lambda: git_provider.add_eyes_reaction(comment_id)
|
||||
)
|
||||
return success
|
||||
|
||||
def run_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
|
||||
return asyncio.run(async_handle_request(pr_url, rest_of_comment, comment_id, git_provider))
|
||||
|
||||
|
||||
def process_comment_sync(pr_url, rest_of_comment, comment_id):
|
||||
try:
|
||||
# Run the async handle_request in a separate function
|
||||
git_provider = get_git_provider()(pr_url=pr_url)
|
||||
success = run_handle_request(pr_url, rest_of_comment, comment_id, git_provider)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()})
|
||||
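The pattern above, re-entering an event loop from a worker process with `asyncio.run`, can be exercised on its own. A minimal sketch with a dummy coroutine standing in for `agent.handle_request`; the PR URL is a placeholder.

```python
import asyncio
import multiprocessing


async def _dummy_handle(pr_url: str, command: str) -> bool:
    # Stand-in for agent.handle_request(...); the real call talks to the git provider and the model.
    await asyncio.sleep(0.1)
    return True


def _worker(pr_url: str, command: str) -> None:
    # Each process gets its own event loop via asyncio.run, mirroring process_comment_sync above.
    asyncio.run(_dummy_handle(pr_url, command))


if __name__ == "__main__":
    p = multiprocessing.Process(target=_worker,
                                args=("https://api.github.com/repos/org/repo/pulls/7", "/review"))
    p.start()
    p.join()
```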
|
||||
|
||||
async def process_comment(pr_url, rest_of_comment, comment_id):
|
||||
try:
|
||||
git_provider = get_git_provider()(pr_url=pr_url)
|
||||
git_provider.set_pr(pr_url)
|
||||
agent = PRAgent()
|
||||
success = await agent.handle_request(
|
||||
pr_url,
|
||||
rest_of_comment,
|
||||
notify=lambda: git_provider.add_eyes_reaction(comment_id)
|
||||
)
|
||||
get_logger().info(f"Finished processing comment for PR: {pr_url}")
|
||||
except Exception as e:
|
||||
get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()})
|
||||
|
||||
async def is_valid_notification(notification, headers, handled_ids, session, user_id):
|
||||
try:
|
||||
if 'reason' in notification and notification['reason'] == 'mention':
|
||||
if 'subject' in notification and notification['subject']['type'] == 'PullRequest':
|
||||
pr_url = notification['subject']['url']
|
||||
latest_comment = notification['subject']['latest_comment_url']
|
||||
if not latest_comment or not isinstance(latest_comment, str):
|
||||
get_logger().debug(f"no latest_comment")
|
||||
return False, handled_ids
|
||||
async with session.get(latest_comment, headers=headers) as comment_response:
|
||||
check_prev_comments = False
|
||||
user_tag = "@" + user_id
|
||||
if comment_response.status == 200:
|
||||
comment = await comment_response.json()
|
||||
if 'id' in comment:
|
||||
if comment['id'] in handled_ids:
|
||||
get_logger().debug(f"comment['id'] in handled_ids")
|
||||
return False, handled_ids
|
||||
else:
|
||||
handled_ids.add(comment['id'])
|
||||
if 'user' in comment and 'login' in comment['user']:
|
||||
if comment['user']['login'] == user_id:
|
||||
get_logger().debug(f"comment['user']['login'] == user_id")
|
||||
check_prev_comments = True
|
||||
comment_body = comment.get('body', '')
|
||||
if not comment_body:
|
||||
get_logger().debug(f"no comment_body")
|
||||
check_prev_comments = True
|
||||
else:
|
||||
if user_tag not in comment_body:
|
||||
get_logger().debug(f"user_tag not in comment_body")
|
||||
check_prev_comments = True
|
||||
else:
|
||||
get_logger().info(f"Polling, pr_url: {pr_url}",
|
||||
artifact={"comment": comment_body})
|
||||
|
||||
if not check_prev_comments:
|
||||
return True, handled_ids, comment, comment_body, pr_url, user_tag
|
||||
else: # we could not find the user tag in the latest comment. Check previous comments
|
||||
# get all comments in the PR
|
||||
requests_url = f"{pr_url}/comments".replace("pulls", "issues")
|
||||
comments_response = requests.get(requests_url, headers=headers)
|
||||
comments = comments_response.json()[::-1]
|
||||
max_comment_to_scan = 4
|
||||
for comment in comments[:max_comment_to_scan]:
|
||||
if 'user' in comment and 'login' in comment['user']:
|
||||
if comment['user']['login'] == user_id:
|
||||
continue
|
||||
comment_body = comment.get('body', '')
|
||||
if not comment_body:
|
||||
continue
|
||||
if user_tag in comment_body:
|
||||
get_logger().info("found user tag in previous comments")
|
||||
get_logger().info(f"Polling, pr_url: {pr_url}",
|
||||
artifact={"comment": comment_body})
|
||||
return True, handled_ids, comment, comment_body, pr_url, user_tag
|
||||
|
||||
get_logger().warning(f"Failed to fetch comments for PR: {pr_url}",
|
||||
artifact={"comments": comments})
|
||||
return False, handled_ids
|
||||
|
||||
return False, handled_ids
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Error processing polling notification",
|
||||
artifact={"notification": notification, "error": e})
|
||||
return False, handled_ids
|
||||
|
||||
|
||||
|
||||
async def polling_loop():
|
||||
"""
|
||||
Polls for notifications and handles them accordingly.
|
||||
"""
|
||||
handled_ids = set()
|
||||
since = [now()]
|
||||
last_modified = [None]
|
||||
git_provider = get_git_provider()()
|
||||
user_id = git_provider.get_user_id()
|
||||
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
|
||||
get_settings().set("pr_description.publish_description_as_comment", True)
|
||||
|
||||
try:
|
||||
deployment_type = get_settings().github.deployment_type
|
||||
token = get_settings().github.user_token
|
||||
except AttributeError:
|
||||
deployment_type = 'none'
|
||||
token = None
|
||||
|
||||
if deployment_type != 'user':
|
||||
raise ValueError("Deployment mode must be set to 'user' to get notifications")
|
||||
if not token:
|
||||
raise ValueError("User token must be set to get notifications")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(5)
|
||||
headers = {
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
"Authorization": f"Bearer {token}"
|
||||
}
|
||||
params = {
|
||||
"participating": "true"
|
||||
}
|
||||
if since[0]:
|
||||
params["since"] = since[0]
|
||||
if last_modified[0]:
|
||||
headers["If-Modified-Since"] = last_modified[0]
|
||||
|
||||
async with session.get(NOTIFICATION_URL, headers=headers, params=params) as response:
|
||||
if response.status == 200:
|
||||
if 'Last-Modified' in response.headers:
|
||||
last_modified[0] = response.headers['Last-Modified']
|
||||
since[0] = None
|
||||
notifications = await response.json()
|
||||
if not notifications:
|
||||
continue
|
||||
get_logger().info(f"Received {len(notifications)} notifications")
|
||||
task_queue = deque()
|
||||
for notification in notifications:
|
||||
if not notification:
|
||||
continue
|
||||
# mark notification as read
|
||||
await mark_notification_as_read(headers, notification, session)
|
||||
|
||||
handled_ids.add(notification['id'])
|
||||
output = await is_valid_notification(notification, headers, handled_ids, session, user_id)
|
||||
if output[0]:
|
||||
_, handled_ids, comment, comment_body, pr_url, user_tag = output
|
||||
rest_of_comment = comment_body.split(user_tag)[1].strip()
|
||||
comment_id = comment['id']
|
||||
|
||||
# Add to the task queue
|
||||
get_logger().info(
|
||||
f"Adding comment processing to task queue for PR, {pr_url}, comment_body: {comment_body}")
|
||||
task_queue.append((process_comment_sync, (pr_url, rest_of_comment, comment_id)))
|
||||
get_logger().info(f"Queued comment processing for PR: {pr_url}")
|
||||
else:
|
||||
get_logger().debug(f"Skipping comment processing for PR")
|
||||
|
||||
max_allowed_parallel_tasks = 10
|
||||
if task_queue:
|
||||
processes = []
|
||||
for i, (func, args) in enumerate(task_queue): # Create parallel tasks
|
||||
p = multiprocessing.Process(target=func, args=args)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
if i > max_allowed_parallel_tasks:
|
||||
get_logger().error(
|
||||
f"Dropping {len(task_queue) - max_allowed_parallel_tasks} tasks from polling session")
|
||||
break
|
||||
task_queue.clear()
|
||||
|
||||
# Don't wait for all processes to complete. Move on to the next iteration
|
||||
# for p in processes:
|
||||
# p.join()
|
||||
|
||||
elif response.status != 304:
|
||||
print(f"Failed to fetch notifications. Status code: {response.status}")
|
||||
|
||||
except Exception as e:
|
||||
get_logger().error(f"Polling exception during processing of a notification: {e}",
|
||||
artifact={"traceback": traceback.format_exc()})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(polling_loop())
|
||||
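Stripped of the comment handling, the polling loop boils down to a single conditional GET against the notifications endpoint. A minimal sketch, assuming a GitHub user token (placeholder below); the `since` and `If-Modified-Since` bookkeeping mirrors what `polling_loop` keeps in `since[0]` and `last_modified[0]`.

```python
import asyncio

import aiohttp

NOTIFICATION_URL = "https://api.github.com/notifications"


async def poll_once(token: str, since: str | None, last_modified: str | None):
    headers = {"Accept": "application/vnd.github.v3+json", "Authorization": f"Bearer {token}"}
    if last_modified:
        headers["If-Modified-Since"] = last_modified
    params = {"participating": "true"}
    if since:
        params["since"] = since
    async with aiohttp.ClientSession() as session:
        async with session.get(NOTIFICATION_URL, headers=headers, params=params) as response:
            if response.status == 304:          # nothing new since the last poll
                return [], last_modified
            response.raise_for_status()
            return await response.json(), response.headers.get("Last-Modified", last_modified)

# asyncio.run(poll_once("<github user token>", None, None))  # token is a placeholder
```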
288
apps/utils/pr_agent/servers/gitlab_webhook.py
Normal file
288
apps/utils/pr_agent/servers/gitlab_webhook.py
Normal file
@ -0,0 +1,288 @@
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
||||
import uvicorn
|
||||
from fastapi import APIRouter, FastAPI, Request, status
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.background import BackgroundTasks
|
||||
from starlette.middleware import Middleware
|
||||
from starlette_context import context
|
||||
from starlette_context.middleware import RawContextMiddleware
|
||||
|
||||
from utils.pr_agent.agent.pr_agent import PRAgent
|
||||
from utils.pr_agent.algo.utils import update_settings_from_args
|
||||
from utils.pr_agent.config_loader import get_settings, global_settings
|
||||
from utils.pr_agent.git_providers.utils import apply_repo_settings
|
||||
from utils.pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||
from utils.pr_agent.secret_providers import get_secret_provider
|
||||
|
||||
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
|
||||
router = APIRouter()
|
||||
|
||||
secret_provider = get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
|
||||
|
||||
|
||||
async def get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id):
|
||||
try:
|
||||
import requests
|
||||
headers = {
|
||||
'Private-Token': f'{gitlab_token}'
|
||||
}
|
||||
# API endpoint to find MRs containing the commit
|
||||
gitlab_url = get_settings().get("GITLAB.URL", 'https://gitlab.com')
|
||||
response = requests.get(
|
||||
f'{gitlab_url}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/merge_requests',
|
||||
headers=headers
|
||||
)
|
||||
merge_requests = response.json()
|
||||
if merge_requests and response.status_code == 200:
|
||||
pr_url = merge_requests[0]['web_url']
|
||||
return pr_url
|
||||
else:
|
||||
get_logger().info(f"No merge requests found for commit: {commit_sha}")
|
||||
return None
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to get MR url from commit sha: {e}")
|
||||
return None
|
||||
|
||||
async def handle_request(api_url: str, body: str, log_context: dict, sender_id: str):
|
||||
log_context["action"] = body
|
||||
log_context["event"] = "pull_request" if body == "/review" else "comment"
|
||||
log_context["api_url"] = api_url
|
||||
log_context["app_name"] = get_settings().get("CONFIG.APP_NAME", "Unknown")
|
||||
|
||||
with get_logger().contextualize(**log_context):
|
||||
await PRAgent().handle_request(api_url, body)
|
||||
|
||||
|
||||
async def _perform_commands_gitlab(commands_conf: str, agent: PRAgent, api_url: str,
|
||||
log_context: dict, data: dict):
|
||||
apply_repo_settings(api_url)
|
||||
if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
|
||||
get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}", **log_context)
|
||||
return
|
||||
if not should_process_pr_logic(data): # Here we already updated the configurations
|
||||
return
|
||||
commands = get_settings().get(f"gitlab.{commands_conf}", {})
|
||||
get_settings().set("config.is_auto_command", True)
|
||||
for command in commands:
|
||||
try:
|
||||
split_command = command.split(" ")
|
||||
command = split_command[0]
|
||||
args = split_command[1:]
|
||||
other_args = update_settings_from_args(args)
|
||||
new_command = ' '.join([command] + other_args)
|
||||
get_logger().info(f"Performing command: {new_command}")
|
||||
with get_logger().contextualize(**log_context):
|
||||
await agent.handle_request(api_url, new_command)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to perform command {command}: {e}")
|
||||
|
||||
|
||||
def is_bot_user(data) -> bool:
|
||||
try:
|
||||
# logic to ignore bot users (unlike Github, no direct flag for bot users in gitlab)
|
||||
sender_name = data.get("user", {}).get("name", "unknown").lower()
|
||||
bot_indicators = ['codium', 'bot_', 'bot-', '_bot', '-bot']
|
||||
if any(indicator in sender_name for indicator in bot_indicators):
|
||||
get_logger().info(f"Skipping GitLab bot user: {sender_name}")
|
||||
return True
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed 'is_bot_user' logic: {e}")
|
||||
return False
|
||||
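A small stand-in mirroring the indicator check above (the payloads are hand-built, and this duplicates rather than imports `is_bot_user`, since importing the webhook module requires GITLAB.URL to be configured):

```python
# Mirrors is_bot_user above on hand-built payloads; illustrative only.
bot_indicators = ['codium', 'bot_', 'bot-', '_bot', '-bot']


def looks_like_bot(data: dict) -> bool:
    sender_name = data.get("user", {}).get("name", "unknown").lower()
    return any(indicator in sender_name for indicator in bot_indicators)


print(looks_like_bot({"user": {"name": "Renovate-Bot"}}))   # True  -> '-bot' indicator matches
print(looks_like_bot({"user": {"name": "Alice Chen"}}))     # False -> handled as a normal user
```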
|
||||
|
||||
def should_process_pr_logic(data) -> bool:
|
||||
try:
|
||||
if not data.get('object_attributes', {}):
|
||||
return False
|
||||
title = data['object_attributes'].get('title')
|
||||
sender = data.get("user", {}).get("username", "")
|
||||
|
||||
# logic to ignore PRs from specific users
|
||||
ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
|
||||
if ignore_pr_users and sender:
|
||||
if sender in ignore_pr_users:
|
||||
get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' settings")
|
||||
return False
|
||||
|
||||
# logic to ignore MRs for titles, labels and source, target branches.
|
||||
ignore_mr_title = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
|
||||
ignore_mr_labels = get_settings().get("CONFIG.IGNORE_PR_LABELS", [])
|
||||
ignore_mr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [])
|
||||
ignore_mr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", [])
|
||||
|
||||
#
|
||||
if ignore_mr_source_branches:
|
||||
source_branch = data['object_attributes'].get('source_branch')
|
||||
if any(re.search(regex, source_branch) for regex in ignore_mr_source_branches):
|
||||
get_logger().info(
|
||||
f"Ignoring MR with source branch '{source_branch}' due to gitlab.ignore_mr_source_branches settings")
|
||||
return False
|
||||
|
||||
if ignore_mr_target_branches:
|
||||
target_branch = data['object_attributes'].get('target_branch')
|
||||
if any(re.search(regex, target_branch) for regex in ignore_mr_target_branches):
|
||||
get_logger().info(
|
||||
f"Ignoring MR with target branch '{target_branch}' due to gitlab.ignore_mr_target_branches settings")
|
||||
return False
|
||||
|
||||
if ignore_mr_labels:
|
||||
labels = [label['title'] for label in data['object_attributes'].get('labels', [])]
|
||||
if any(label in ignore_mr_labels for label in labels):
|
||||
labels_str = ", ".join(labels)
|
||||
get_logger().info(f"Ignoring MR with labels '{labels_str}' due to gitlab.ignore_mr_labels settings")
|
||||
return False
|
||||
|
||||
if ignore_mr_title:
|
||||
if any(re.search(regex, title) for regex in ignore_mr_title):
|
||||
get_logger().info(f"Ignoring MR with title '{title}' due to gitlab.ignore_mr_title settings")
|
||||
return False
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed 'should_process_pr_logic': {e}")
|
||||
return True
|
||||
|
||||
|
||||
@router.post("/webhook")
|
||||
async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
|
||||
start_time = datetime.now()
|
||||
request_json = await request.json()
|
||||
context["settings"] = copy.deepcopy(global_settings)
|
||||
|
||||
async def inner(data: dict):
|
||||
log_context = {"server_type": "gitlab_app"}
|
||||
get_logger().debug("Received a GitLab webhook")
|
||||
if request.headers.get("X-Gitlab-Token") and secret_provider:
|
||||
request_token = request.headers.get("X-Gitlab-Token")
|
||||
secret = secret_provider.get_secret(request_token)
|
||||
if not secret:
|
||||
get_logger().warning(f"Empty secret retrieved, request_token: {request_token}")
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content=jsonable_encoder({"message": "unauthorized"}))
|
||||
try:
|
||||
secret_dict = json.loads(secret)
|
||||
gitlab_token = secret_dict["gitlab_token"]
|
||||
log_context["token_id"] = secret_dict.get("token_name", secret_dict.get("id", "unknown"))
|
||||
context["settings"].gitlab.personal_access_token = gitlab_token
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to validate secret {request_token}: {e}")
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
|
||||
elif get_settings().get("GITLAB.SHARED_SECRET"):
|
||||
secret = get_settings().get("GITLAB.SHARED_SECRET")
|
||||
if not request.headers.get("X-Gitlab-Token") == secret:
|
||||
get_logger().error("Failed to validate secret")
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
|
||||
else:
|
||||
get_logger().error("Failed to validate secret")
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
|
||||
gitlab_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
|
||||
if not gitlab_token:
|
||||
get_logger().error("No gitlab token found")
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
|
||||
|
||||
get_logger().info("GitLab data", artifact=data)
|
||||
sender = data.get("user", {}).get("username", "unknown")
|
||||
sender_id = data.get("user", {}).get("id", "unknown")
|
||||
|
||||
# ignore bot users
|
||||
if is_bot_user(data):
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
|
||||
if data.get('event_type') != 'note': # not a comment
|
||||
# ignore MRs based on title, labels, source and target branches
|
||||
if not should_process_pr_logic(data):
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
|
||||
|
||||
log_context["sender"] = sender
|
||||
if data.get('object_kind') == 'merge_request' and data['object_attributes'].get('action') in ['open', 'reopen']:
|
||||
title = data['object_attributes'].get('title')
|
||||
url = data['object_attributes'].get('url')
|
||||
draft = data['object_attributes'].get('draft')
|
||||
get_logger().info(f"New merge request: {url}")
|
||||
if draft:
|
||||
get_logger().info(f"Skipping draft MR: {url}")
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
|
||||
|
||||
await _perform_commands_gitlab("pr_commands", PRAgent(), url, log_context, data)
|
||||
elif data.get('object_kind') == 'note' and data.get('event_type') == 'note': # comment on MR
|
||||
if 'merge_request' in data:
|
||||
mr = data['merge_request']
|
||||
url = mr.get('url')
|
||||
|
||||
get_logger().info(f"A comment has been added to a merge request: {url}")
|
||||
body = data.get('object_attributes', {}).get('note')
|
||||
if data.get('object_attributes', {}).get('type') == 'DiffNote' and '/ask' in body: # /ask_line
|
||||
body = handle_ask_line(body, data)
|
||||
|
||||
await handle_request(url, body, log_context, sender_id)
|
||||
elif data.get('object_kind') == 'push' and data.get('event_name') == 'push':
|
||||
try:
|
||||
project_id = data['project_id']
|
||||
commit_sha = data['checkout_sha']
|
||||
url = await get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id)
|
||||
if not url:
|
||||
get_logger().info(f"No MR found for commit: {commit_sha}")
|
||||
return JSONResponse(status_code=status.HTTP_200_OK,
|
||||
content=jsonable_encoder({"message": "success"}))
|
||||
|
||||
# we need first to apply_repo_settings
|
||||
apply_repo_settings(url)
|
||||
commands_on_push = get_settings().get(f"gitlab.push_commands", {})
|
||||
handle_push_trigger = get_settings().get(f"gitlab.handle_push_trigger", False)
|
||||
if not commands_on_push or not handle_push_trigger:
|
||||
get_logger().info("Push event, but no push commands found or push trigger is disabled")
|
||||
return JSONResponse(status_code=status.HTTP_200_OK,
|
||||
content=jsonable_encoder({"message": "success"}))
|
||||
|
||||
get_logger().debug(f'A push event has been received: {url}')
|
||||
await _perform_commands_gitlab("push_commands", PRAgent(), url, log_context, data)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to handle push event: {e}")
|
||||
|
||||
background_tasks.add_task(inner, request_json)
|
||||
end_time = datetime.now()
|
||||
get_logger().info(f"Processing time: {end_time - start_time}", request=request_json)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
|
||||
|
||||
|
||||
def handle_ask_line(body, data):
|
||||
try:
|
||||
line_range_ = data['object_attributes']['position']['line_range']
|
||||
# if line_range_['start']['type'] == 'new':
|
||||
start_line = line_range_['start']['new_line']
|
||||
end_line = line_range_['end']['new_line']
|
||||
# else:
|
||||
# start_line = line_range_['start']['old_line']
|
||||
# end_line = line_range_['end']['old_line']
|
||||
question = body.replace('/ask', '').strip()
|
||||
path = data['object_attributes']['position']['new_path']
|
||||
side = 'RIGHT' # if line_range_['start']['type'] == 'new' else 'LEFT'
|
||||
comment_id = data['object_attributes']["discussion_id"]
|
||||
get_logger().info("Handling line comment")
|
||||
body = f"/ask_line --line_start={start_line} --line_end={end_line} --side={side} --file_name={path} --comment_id={comment_id} {question}"
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to handle ask line comment: {e}")
|
||||
return body
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def root():
|
||||
return {"status": "ok"}
|
||||
|
||||
gitlab_url = get_settings().get("GITLAB.URL", None)
|
||||
if not gitlab_url:
|
||||
raise ValueError("GITLAB.URL is not set")
|
||||
get_settings().config.git_provider = "gitlab"
|
||||
middleware = [Middleware(RawContextMiddleware)]
|
||||
app = FastAPI(middleware=middleware)
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
def start():
|
||||
uvicorn.run(app, host="0.0.0.0", port=3000)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
start()
|
||||
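To exercise the endpoint locally, one can replay a trimmed-down note event against the running server. The shared secret and MR URL below are placeholders, and the payload only carries the fields the handler above actually reads; this is a test sketch, not a real GitLab event.

```python
import requests

# Placeholder values; the X-Gitlab-Token header must match GITLAB.SHARED_SECRET (or the secret provider).
payload = {
    "event_type": "note",
    "object_kind": "note",
    "user": {"username": "alice", "id": 12, "name": "Alice"},
    "merge_request": {"url": "https://gitlab.example.com/group/project/-/merge_requests/5"},
    "object_attributes": {"note": "/review", "type": None},
}

resp = requests.post(
    "http://localhost:3000/webhook",
    json=payload,
    headers={"X-Gitlab-Token": "<shared secret>"},
)
print(resp.status_code, resp.json())   # the handler replies 200 and processes the comment in the background
```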
191
apps/utils/pr_agent/servers/gunicorn_config.py
Normal file
191
apps/utils/pr_agent/servers/gunicorn_config.py
Normal file
@ -0,0 +1,191 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
|
||||
# from prometheus_client import multiprocess
|
||||
|
||||
# Sample Gunicorn configuration file.
|
||||
|
||||
#
|
||||
# Server socket
|
||||
#
|
||||
# bind - The socket to bind.
|
||||
#
|
||||
# A string of the form: 'HOST', 'HOST:PORT', 'unix:PATH'.
|
||||
# An IP is a valid HOST.
|
||||
#
|
||||
# backlog - The number of pending connections. This refers
|
||||
# to the number of clients that can be waiting to be
|
||||
# served. Exceeding this number results in the client
|
||||
# getting an error when attempting to connect. It should
|
||||
# only affect servers under significant load.
|
||||
#
|
||||
# Must be a positive integer. Generally set in the 64-2048
|
||||
# range.
|
||||
#
|
||||
|
||||
# bind = '0.0.0.0:5000'
|
||||
bind = '0.0.0.0:3000'
|
||||
backlog = 2048
|
||||
|
||||
#
|
||||
# Worker processes
|
||||
#
|
||||
# workers - The number of worker processes that this server
|
||||
# should keep alive for handling requests.
|
||||
#
|
||||
# A positive integer generally in the 2-4 x $(NUM_CORES)
|
||||
# range. You'll want to vary this a bit to find the best
|
||||
# for your particular application's work load.
|
||||
#
|
||||
# worker_class - The type of workers to use. The default
|
||||
# sync class should handle most 'normal' types of work
|
||||
# loads. You'll want to read
|
||||
# http://docs.gunicorn.org/en/latest/design.html#choosing-a-worker-type
|
||||
# for information on when you might want to choose one
|
||||
# of the other worker classes.
|
||||
#
|
||||
# A string referring to a Python path to a subclass of
|
||||
# gunicorn.workers.base.Worker. The default provided values
|
||||
# can be seen at
|
||||
# http://docs.gunicorn.org/en/latest/settings.html#worker-class
|
||||
#
|
||||
# worker_connections - For the eventlet and gevent worker classes
|
||||
# this limits the maximum number of simultaneous clients that
|
||||
# a single process can handle.
|
||||
#
|
||||
# A positive integer generally set to around 1000.
|
||||
#
|
||||
# timeout - If a worker does not notify the master process in this
|
||||
# number of seconds it is killed and a new worker is spawned
|
||||
# to replace it.
|
||||
#
|
||||
# Generally set to thirty seconds. Only set this noticeably
|
||||
# higher if you're sure of the repercussions for sync workers.
|
||||
# For the non sync workers it just means that the worker
|
||||
# process is still communicating and is not tied to the length
|
||||
# of time required to handle a single request.
|
||||
#
|
||||
# keepalive - The number of seconds to wait for the next request
|
||||
# on a Keep-Alive HTTP connection.
|
||||
#
|
||||
# A positive integer. Generally set in the 1-5 seconds range.
|
||||
#
|
||||
|
||||
if os.getenv('GUNICORN_WORKERS', None):
|
||||
workers = int(os.getenv('GUNICORN_WORKERS'))
|
||||
else:
|
||||
cores = multiprocessing.cpu_count()
|
||||
workers = cores * 2 + 1
|
||||
worker_connections = 1000
|
||||
timeout = 240
|
||||
keepalive = 2
|
||||
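For a concrete instance of the sizing rule above: with GUNICORN_WORKERS unset, a 4-core host gets 2*4+1 = 9 workers. The snippet mirrors the logic in this config file.

```python
import multiprocessing
import os

# Same rule as above: an explicit GUNICORN_WORKERS override wins, otherwise 2 x cores + 1.
workers = int(os.getenv("GUNICORN_WORKERS", multiprocessing.cpu_count() * 2 + 1))
print(workers)   # e.g. 9 on a 4-core machine
```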
|
||||
#
|
||||
# spew - Install a trace function that spews every line of Python
|
||||
# that is executed when running the server. This is the
|
||||
# nuclear option.
|
||||
#
|
||||
# True or False
|
||||
#
|
||||
|
||||
spew = False
|
||||
|
||||
#
|
||||
# Server mechanics
|
||||
#
|
||||
# daemon - Detach the main Gunicorn process from the controlling
|
||||
# terminal with a standard fork/fork sequence.
|
||||
#
|
||||
# True or False
|
||||
#
|
||||
# raw_env - Pass environment variables to the execution environment.
|
||||
#
|
||||
# pidfile - The path to a pid file to write
|
||||
#
|
||||
# A path string or None to not write a pid file.
|
||||
#
|
||||
# user - Switch worker processes to run as this user.
|
||||
#
|
||||
# A valid user id (as an integer) or the name of a user that
|
||||
# can be retrieved with a call to pwd.getpwnam(value) or None
|
||||
# to not change the worker process user.
|
||||
#
|
||||
# group - Switch worker process to run as this group.
|
||||
#
|
||||
# A valid group id (as an integer) or the name of a group that
|
||||
# can be retrieved with a call to grp.getgrnam(value) or None
|
||||
# to not change the worker process group.
|
||||
#
|
||||
# umask - A mask for file permissions written by Gunicorn. Note that
|
||||
# this affects unix socket permissions.
|
||||
#
|
||||
# A valid value for the os.umask(mode) call or a string
|
||||
# compatible with int(value, 0) (0 means Python guesses
|
||||
# the base, so values like "0", "0xFF", "0022" are valid
|
||||
# for decimal, hex, and octal representations)
|
||||
#
|
||||
# tmp_upload_dir - A directory to store temporary request data when
|
||||
# requests are read. This will most likely be disappearing soon.
|
||||
#
|
||||
# A path to a directory where the process owner can write. Or
|
||||
# None to signal that Python should choose one on its own.
|
||||
#
|
||||
|
||||
daemon = False
|
||||
raw_env = []
|
||||
pidfile = None
|
||||
umask = 0
|
||||
user = None
|
||||
group = None
|
||||
tmp_upload_dir = None
|
||||
|
||||
#
|
||||
# Logging
|
||||
#
|
||||
# logfile - The path to a log file to write to.
|
||||
#
|
||||
# A path string. "-" means log to stdout.
|
||||
#
|
||||
# loglevel - The granularity of log output
|
||||
#
|
||||
# A string of "debug", "info", "warning", "error", "critical"
|
||||
#
|
||||
|
||||
errorlog = '-'
|
||||
loglevel = 'info'
|
||||
accesslog = None
|
||||
access_log_format = '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"'
|
||||
|
||||
#
|
||||
# Process naming
|
||||
#
|
||||
# proc_name - A base to use with setproctitle to change the way
|
||||
# that Gunicorn processes are reported in the system process
|
||||
# table. This affects things like 'ps' and 'top'. If you're
|
||||
# going to be running more than one instance of Gunicorn you'll
|
||||
# probably want to set a name to tell them apart. This requires
|
||||
# that you install the setproctitle module.
|
||||
#
|
||||
# A string or None to choose a default of something like 'gunicorn'.
|
||||
#
|
||||
|
||||
proc_name = None
|
||||
|
||||
|
||||
#
|
||||
# Server hooks
|
||||
#
|
||||
# post_fork - Called just after a worker has been forked.
|
||||
#
|
||||
# A callable that takes a server and worker instance
|
||||
# as arguments.
|
||||
#
|
||||
# pre_fork - Called just prior to forking the worker subprocess.
|
||||
#
|
||||
# A callable that accepts the same arguments as after_fork
|
||||
#
|
||||
# pre_exec - Called just prior to forking off a secondary
|
||||
# master process during things like config reloading.
|
||||
#
|
||||
# A callable that takes a server instance as the sole argument.
|
||||
#
|
||||
203
apps/utils/pr_agent/servers/help.py
Normal file
203
apps/utils/pr_agent/servers/help.py
Normal file
@ -0,0 +1,203 @@
|
||||
class HelpMessage:
|
||||
@staticmethod
|
||||
def get_general_commands_text():
|
||||
commands_text = "> - **/review**: Request a review of your Pull Request. \n" \
|
||||
"> - **/describe**: Update the PR title and description based on the contents of the PR. \n" \
|
||||
"> - **/improve [--extended]**: Suggest code improvements. Extended mode provides a higher quality feedback. \n" \
|
||||
"> - **/ask \\<QUESTION\\>**: Ask a question about the PR. \n" \
|
||||
"> - **/update_changelog**: Update the changelog based on the PR's contents. \n" \
|
||||
"> - **/add_docs** 💎: Generate docstring for new components introduced in the PR. \n" \
|
||||
"> - **/generate_labels** 💎: Generate labels for the PR based on the PR's contents. \n" \
|
||||
"> - **/analyze** 💎: Automatically analyzes the PR, and presents changes walkthrough for each component. \n\n" \
|
||||
">See the [tools guide](https://pr-agent-docs.codium.ai/tools/) for more details.\n" \
|
||||
">To list the possible configuration parameters, add a **/config** comment. \n"
|
||||
return commands_text
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_general_bot_help_text():
|
||||
output = f"> To invoke the PR-Agent, add a comment using one of the following commands: \n{HelpMessage.get_general_commands_text()} \n"
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def get_review_usage_guide():
|
||||
output ="**Overview:**\n"
|
||||
output +=("The `review` tool scans the PR code changes, and generates a PR review which includes several types of feedbacks, such as possible PR issues, security threats and relevant test in the PR. More feedbacks can be [added](https://pr-agent-docs.codium.ai/tools/review/#general-configurations) by configuring the tool.\n\n"
|
||||
"The tool can be triggered [automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) every time a new PR is opened, or can be invoked manually by commenting on any PR.\n")
|
||||
output +="""\
|
||||
- When commenting, to edit [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L23) related to the review tool (`pr_reviewer` section), use the following template:
|
||||
```
|
||||
/review --pr_reviewer.some_config1=... --pr_reviewer.some_config2=...
|
||||
```
|
||||
- With a [configuration file](https://pr-agent-docs.codium.ai/usage-guide/configuration_options/), use the following template:
|
||||
```
|
||||
[pr_reviewer]
|
||||
some_config1=...
|
||||
some_config2=...
|
||||
```
|
||||
"""
|
||||
|
||||
output += f"\n\nSee the review [usage page](https://pr-agent-docs.codium.ai/tools/review/) for a comprehensive guide on using this tool.\n\n"
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_describe_usage_guide():
|
||||
output = "**Overview:**\n"
|
||||
output += "The `describe` tool scans the PR code changes, and generates a description for the PR - title, type, summary, walkthrough and labels. "
|
||||
output += "The tool can be triggered [automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) every time a new PR is opened, or can be invoked manually by commenting on a PR.\n"
|
||||
output += """\
|
||||
|
||||
When commenting, to edit [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L46) related to the describe tool (`pr_description` section), use the following template:
|
||||
```
|
||||
/describe --pr_description.some_config1=... --pr_description.some_config2=...
|
||||
```
|
||||
With a [configuration file](https://pr-agent-docs.codium.ai/usage-guide/configuration_options/), use the following template:
|
||||
```
|
||||
[pr_description]
|
||||
some_config1=...
|
||||
some_config2=...
|
||||
```
|
||||
"""
|
||||
output += "\n\n<table>"
|
||||
|
||||
# automation
|
||||
output += "<tr><td><details> <summary><strong> Enabling\\disabling automation </strong></summary><hr>\n\n"
|
||||
output += """\
|
||||
- When you first install the app, the [default mode](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) for the describe tool is:
|
||||
```
|
||||
pr_commands = ["/describe", ...]
|
||||
```
|
||||
meaning the `describe` tool will run automatically on every PR.
|
||||
|
||||
- Markers are an alternative way to control the generated description, to give maximal control to the user. If you set:
|
||||
```
|
||||
pr_commands = ["/describe --pr_description.use_description_markers=true", ...]
|
||||
```
|
||||
the tool will replace every marker of the form `pr_agent:marker_name` in the PR description with the relevant content, where `marker_name` is one of the following:
|
||||
- `type`: the PR type.
|
||||
- `summary`: the PR summary.
|
||||
- `walkthrough`: the PR walkthrough.
|
||||
|
||||
Note that when markers are enabled, if the original PR description does not contain any markers, the tool will not alter the description at all.
|
||||
|
||||
"""
|
||||
output += "\n\n</details></td></tr>\n\n"
|
||||
|
||||
# custom labels
|
||||
output += "<tr><td><details> <summary><strong> Custom labels </strong></summary><hr>\n\n"
|
||||
output += """\
|
||||
The default labels of the `describe` tool are quite generic: [`Bug fix`, `Tests`, `Enhancement`, `Documentation`, `Other`].
|
||||
|
||||
If you specify [custom labels](https://pr-agent-docs.codium.ai/tools/describe/#handle-custom-labels-from-the-repos-labels-page) in the repo's labels page or via configuration file, you can get tailored labels for your use cases.
|
||||
Examples for custom labels:
|
||||
- `Main topic:performance` - pr_agent:The main topic of this PR is performance
|
||||
- `New endpoint` - pr_agent:A new endpoint was added in this PR
|
||||
- `SQL query` - pr_agent:A new SQL query was added in this PR
|
||||
- `Dockerfile changes` - pr_agent:The PR contains changes in the Dockerfile
|
||||
- ...
|
||||
|
||||
The list above is eclectic, and aims to give an idea of different possibilities. Define custom labels that are relevant for your repo and use cases.
|
||||
Note that Labels are not mutually exclusive, so you can add multiple label categories.
|
||||
Make sure to provide proper title, and a detailed and well-phrased description for each label, so the tool will know when to suggest it.
|
||||
"""
|
||||
output += "\n\n</details></td></tr>\n\n"
|
||||
|
||||
# Inline File Walkthrough
|
||||
output += "<tr><td><details> <summary><strong> Inline File Walkthrough 💎</strong></summary><hr>\n\n"
|
||||
output += """\
|
||||
For enhanced user experience, the `describe` tool can add file summaries directly to the "Files changed" tab in the PR page.
|
||||
This will enable you to quickly understand the changes in each file, while reviewing the code changes (diffs).
|
||||
|
||||
To enable inline file summary, set `pr_description.inline_file_summary` in the configuration file, possible values are:
|
||||
- `'table'`: File changes walkthrough table will be displayed on the top of the "Files changed" tab, in addition to the "Conversation" tab.
|
||||
- `true`: A collapsible file comment with a changes title and a changes summary for each file in the PR.
|
||||
- `false` (default): File changes walkthrough will be added only to the "Conversation" tab.
|
||||
"""
|
||||
|
||||
# extra instructions
|
||||
output += "<tr><td><details> <summary><strong> Utilizing extra instructions</strong></summary><hr>\n\n"
|
||||
output += '''\
|
||||
The `describe` tool can be configured with extra instructions, to guide the model to a feedback tailored to the needs of your project.
|
||||
|
||||
Be specific, clear, and concise in the instructions. With extra instructions, you are the prompter. Notice that the general structure of the description is fixed, and cannot be changed. Extra instructions can change the content or style of each sub-section of the PR description.
|
||||
|
||||
Examples for extra instructions:
|
||||
```
|
||||
[pr_description]
|
||||
extra_instructions="""\
|
||||
- The PR title should be in the format: '<PR type>: <title>'
|
||||
- The title should be short and concise (up to 10 words)
|
||||
- ...
|
||||
"""
|
||||
```
|
||||
Use triple quotes to write multi-line instructions. Use bullet points to make the instructions more readable.
|
||||
'''
|
||||
output += "\n\n</details></td></tr>\n\n"
|
||||
|
||||
|
||||
# general
|
||||
output += "\n\n<tr><td><details> <summary><strong> More PR-Agent commands</strong></summary><hr> \n\n"
|
||||
output += HelpMessage.get_general_bot_help_text()
|
||||
output += "\n\n</details></td></tr>\n\n"
|
||||
|
||||
output += "</table>"
|
||||
|
||||
output += f"\n\nSee the [describe usage](https://pr-agent-docs.codium.ai/tools/describe/) page for a comprehensive guide on using this tool.\n\n"
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def get_ask_usage_guide():
|
||||
output = "**Overview:**\n"
|
||||
output += """\
|
||||
The `ask` tool answers questions about the PR, based on the PR code changes.
|
||||
It can be invoked manually by commenting on any PR:
|
||||
```
|
||||
/ask "..."
|
||||
```
|
||||
|
||||
Note that the tool does not have "memory" of previous questions, and answers each question independently.
|
||||
You can ask questions about the entire PR, about specific code lines, or about an image related to the PR code changes.
|
||||
"""
|
||||
# output += "\n\n<table>"
|
||||
#
|
||||
# # # general
|
||||
# # output += "\n\n<tr><td><details> <summary><strong> More PR-Agent commands</strong></summary><hr> \n\n"
|
||||
# # output += HelpMessage.get_general_bot_help_text()
|
||||
# # output += "\n\n</details></td></tr>\n\n"
|
||||
#
|
||||
# output += "</table>"
|
||||
|
||||
output += f"\n\nSee the [ask usage](https://pr-agent-docs.codium.ai/tools/ask/) page for a comprehensive guide on using this tool.\n\n"
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_improve_usage_guide():
|
||||
output = "**Overview:**\n"
|
||||
output += "The code suggestions tool, named `improve`, scans the PR code changes, and automatically generates code suggestions for improving the PR."
|
||||
output += "The tool can be triggered [automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) every time a new PR is opened, or can be invoked manually by commenting on a PR.\n"
|
||||
output += """\
|
||||
- When commenting, to edit [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L78) related to the improve tool (`pr_code_suggestions` section), use the following template:
|
||||
|
||||
```
|
||||
/improve --pr_code_suggestions.some_config1=... --pr_code_suggestions.some_config2=...
|
||||
```
|
||||
|
||||
- With a [configuration file](https://pr-agent-docs.codium.ai/usage-guide/configuration_options/), use the following template:
|
||||
|
||||
```
|
||||
[pr_code_suggestions]
|
||||
some_config1=...
|
||||
some_config2=...
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
output += f"\n\nSee the improve [usage page](https://pr-agent-docs.codium.ai/tools/improve/) for a comprehensive guide on using this tool.\n\n"
|
||||
|
||||
return output
|
||||
16
apps/utils/pr_agent/servers/serverless.py
Normal file
16
apps/utils/pr_agent/servers/serverless.py
Normal file
@ -0,0 +1,16 @@
|
||||
from fastapi import FastAPI
|
||||
from mangum import Mangum
|
||||
from starlette.middleware import Middleware
|
||||
from starlette_context.middleware import RawContextMiddleware
|
||||
|
||||
from utils.pr_agent.servers.github_app import router
|
||||
|
||||
middleware = [Middleware(RawContextMiddleware)]
|
||||
app = FastAPI(middleware=middleware)
|
||||
app.include_router(router)
|
||||
|
||||
handler = Mangum(app, lifespan="off")
|
||||
|
||||
|
||||
def serverless(event, context):
|
||||
return handler(event, context)
|
||||
86
apps/utils/pr_agent/servers/utils.py
Normal file
86
apps/utils/pr_agent/servers/utils.py
Normal file
@ -0,0 +1,86 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def verify_signature(payload_body, secret_token, signature_header):
|
||||
"""Verify that the payload was sent from GitHub by validating SHA256.
|
||||
|
||||
Raise and return 403 if not authorized.
|
||||
|
||||
Args:
|
||||
payload_body: original request body to verify (request.body())
|
||||
secret_token: GitHub app webhook token (WEBHOOK_SECRET)
|
||||
signature_header: header received from GitHub (x-hub-signature-256)
|
||||
"""
|
||||
if not signature_header:
|
||||
raise HTTPException(status_code=403, detail="x-hub-signature-256 header is missing!")
|
||||
hash_object = hmac.new(secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256)
|
||||
expected_signature = "sha256=" + hash_object.hexdigest()
|
||||
if not hmac.compare_digest(expected_signature, signature_header):
|
||||
raise HTTPException(status_code=403, detail="Request signatures didn't match!")
|
||||
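The signature GitHub sends can be reproduced locally. A minimal sketch that signs a sample payload with a placeholder secret and runs it through `verify_signature`; the secret and body below are illustrative, not real values.

```python
import hashlib
import hmac

from utils.pr_agent.servers.utils import verify_signature

secret_token = "my-webhook-secret"        # placeholder; should match the app's WEBHOOK_SECRET
payload_body = b'{"action": "opened"}'    # raw request body bytes

# This mirrors what GitHub puts in the x-hub-signature-256 header.
signature_header = "sha256=" + hmac.new(secret_token.encode("utf-8"),
                                        msg=payload_body,
                                        digestmod=hashlib.sha256).hexdigest()

verify_signature(payload_body, secret_token, signature_header)  # passes silently; a bad header raises 403
```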
|
||||
|
||||
class RateLimitExceeded(Exception):
|
||||
"""Raised when the git provider API rate limit has been exceeded."""
|
||||
pass
|
||||
|
||||
|
||||
class DefaultDictWithTimeout(defaultdict):
|
||||
"""A defaultdict with a time-to-live (TTL)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_factory: Callable[[], Any] = None,
|
||||
ttl: int = None,
|
||||
refresh_interval: int = 60,
|
||||
update_key_time_on_get: bool = True,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
default_factory: The default factory to use for keys that are not in the dictionary.
|
||||
ttl: The time-to-live (TTL) in seconds.
|
||||
refresh_interval: How often to refresh the dict and delete items older than the TTL.
|
||||
update_key_time_on_get: Whether to update the access time of a key also on get (or only when set).
|
||||
"""
|
||||
super().__init__(default_factory, *args, **kwargs)
|
||||
self.__key_times = dict()
|
||||
self.__ttl = ttl
|
||||
self.__refresh_interval = refresh_interval
|
||||
self.__update_key_time_on_get = update_key_time_on_get
|
||||
self.__last_refresh = self.__time() - self.__refresh_interval
|
||||
|
||||
@staticmethod
|
||||
def __time():
|
||||
return time.monotonic()
|
||||
|
||||
def __refresh(self):
|
||||
if self.__ttl is None:
|
||||
return
|
||||
request_time = self.__time()
|
||||
if request_time - self.__last_refresh < self.__refresh_interval:  # not yet due for a refresh
|
||||
return
|
||||
to_delete = [key for key, key_time in self.__key_times.items() if request_time - key_time > self.__ttl]
|
||||
for key in to_delete:
|
||||
del self[key]
|
||||
self.__last_refresh = request_time
|
||||
|
||||
def __getitem__(self, __key):
|
||||
if self.__update_key_time_on_get:
|
||||
self.__key_times[__key] = self.__time()
|
||||
self.__refresh()
|
||||
return super().__getitem__(__key)
|
||||
|
||||
def __setitem__(self, __key, __value):
|
||||
self.__key_times[__key] = self.__time()
|
||||
return super().__setitem__(__key, __value)
|
||||
|
||||
def __delitem__(self, __key):
|
||||
del self.__key_times[__key]
|
||||
return super().__delitem__(__key)
|
||||
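A usage sketch for the TTL dict above, for example as a per-delivery duplicate guard. The numbers are illustrative, `update_key_time_on_get=False` is chosen so reads do not keep a key alive, and the expiry shown assumes the refresh guard skips work only while the interval has not yet elapsed (as documented above).

```python
import time

from utils.pr_agent.servers.utils import DefaultDictWithTimeout

# Count events per key; keys are dropped ~30s after their last write (cleanup checked every 5s).
seen = DefaultDictWithTimeout(int, ttl=30, refresh_interval=5, update_key_time_on_get=False)

seen["delivery-123"] += 1
seen["delivery-123"] += 1
print(seen["delivery-123"])   # 2

time.sleep(31)                # past the TTL; the next access triggers a refresh that drops the key
print(seen["delivery-123"])   # 0 -- recreated by the int() default factory
```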
96
apps/utils/pr_agent/settings/.secrets_template.toml
Normal file
96
apps/utils/pr_agent/settings/.secrets_template.toml
Normal file
@ -0,0 +1,96 @@
|
||||
# QUICKSTART:
|
||||
# Copy this file to .secrets.toml in the same folder.
|
||||
# The minimum workable settings - set openai.key to your API key.
|
||||
# Set github.deployment_type to "user" and github.user_token to your GitHub personal access token.
|
||||
# This will allow you to run the CLI scripts in the scripts/ folder and the github_polling server.
|
||||
#
|
||||
# See README for details about GitHub App deployment.
|
||||
|
||||
[openai]
|
||||
key = "" # Acquire through https://platform.openai.com
|
||||
#org = "<ORGANIZATION>" # Optional, may be commented out.
|
||||
# Uncomment the following for Azure OpenAI
|
||||
#api_type = "azure"
|
||||
#api_version = '2023-05-15' # Check Azure documentation for the current API version
|
||||
#api_base = "" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
|
||||
#deployment_id = "" # The deployment name you chose when you deployed the engine
|
||||
#fallback_deployments = [] # For each fallback model specified in configuration.toml in the [config] section, specify the appropriate deployment_id
|
||||
|
||||
[pinecone]
|
||||
api_key = "..."
|
||||
environment = "gcp-starter"
|
||||
|
||||
[anthropic]
|
||||
key = "" # Optional, uncomment if you want to use Anthropic. Acquire through https://www.anthropic.com/
|
||||
|
||||
[cohere]
|
||||
key = "" # Optional, uncomment if you want to use Cohere. Acquire through https://dashboard.cohere.ai/
|
||||
|
||||
[replicate]
|
||||
key = "" # Optional, uncomment if you want to use Replicate. Acquire through https://replicate.com/
|
||||
|
||||
[groq]
|
||||
key = "" # Acquire through https://console.groq.com/keys
|
||||
|
||||
[huggingface]
|
||||
key = "" # Optional, uncomment if you want to use Huggingface Inference API. Acquire through https://huggingface.co/docs/api-inference/quicktour
|
||||
api_base = "" # the base url for your huggingface inference endpoint
|
||||
|
||||
[ollama]
|
||||
api_base = "" # the base url for your local Llama 2, Code Llama, and other models inference endpoint. Acquire through https://ollama.ai/
|
||||
|
||||
[vertexai]
|
||||
vertex_project = "" # the google cloud platform project name for your vertexai deployment
|
||||
vertex_location = "" # the google cloud platform location for your vertexai deployment
|
||||
|
||||
[google_ai_studio]
|
||||
gemini_api_key = "" # the google AI Studio API key
|
||||
|
||||
[github]
|
||||
# ---- Set the following only for deployment type == "user"
|
||||
user_token = "" # A GitHub personal access token with 'repo' scope.
|
||||
deployment_type = "user" #set to user by default
|
||||
|
||||
# ---- Set the following only for deployment type == "app", see README for details.
|
||||
private_key = """\
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
<GITHUB PRIVATE KEY>
|
||||
-----END RSA PRIVATE KEY-----
|
||||
"""
|
||||
app_id = 123456 # The GitHub App ID, replace with your own.
|
||||
webhook_secret = "<WEBHOOK SECRET>" # Optional, may be commented out.
|
||||
|
||||
[gitlab]
|
||||
# Gitlab personal access token
|
||||
personal_access_token = ""
|
||||
shared_secret = "" # webhook secret
|
||||
|
||||
[bitbucket]
|
||||
# For Bitbucket personal/repository bearer token
|
||||
bearer_token = ""
|
||||
|
||||
[bitbucket_server]
|
||||
# For Bitbucket Server bearer token
|
||||
bearer_token = ""
|
||||
webhook_secret = ""
|
||||
|
||||
# For Bitbucket app
|
||||
app_key = ""
|
||||
base_url = ""
|
||||
|
||||
[litellm]
|
||||
LITELLM_TOKEN = "" # see https://docs.litellm.ai/docs/debugging/hosted_debugging for details and instructions on how to get a token
|
||||
|
||||
[azure_devops]
|
||||
# For Azure devops personal access token
|
||||
org = ""
|
||||
pat = ""
|
||||
|
||||
[azure_devops_server]
|
||||
# For Azure devops Server basic auth - configured in the webhook creation
|
||||
# Optional, uncomment if you want to use Azure devops webhooks. Value assigned when you create the webhook
|
||||
# webhook_username = "<basic auth user>"
|
||||
# webhook_password = "<basic auth password>"
|
||||
|
||||
[deepseek]
|
||||
key = ""
|
||||
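These TOML files are read through Dynaconf (listed in the Pipfile). A minimal loading sketch; the file paths are assumptions based on this repo's layout, `.secrets.toml` is the filled-in copy of the template above, and the dotted lookups rely on Dynaconf's documented key traversal.

```python
from dynaconf import Dynaconf

# Paths are assumptions based on this repo's layout.
settings = Dynaconf(
    settings_files=[
        "apps/utils/pr_agent/settings/configuration.toml",
        "apps/utils/pr_agent/settings/.secrets.toml",
    ]
)

print(settings.get("config.git_provider"))   # "gitlab" from configuration.toml
print(bool(settings.get("openai.key")))      # True once the secret has been filled in
```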
333
apps/utils/pr_agent/settings/configuration.toml
Normal file
333
apps/utils/pr_agent/settings/configuration.toml
Normal file
@ -0,0 +1,333 @@
|
||||
[config]
|
||||
# models
|
||||
model="o3-mini"
|
||||
fallback_models=["o3-mini"]
|
||||
# model_weak="gpt-4o-mini" # optional, a weaker model to use for some easier tasks
|
||||
# CLI
|
||||
git_provider="gitlab"
|
||||
publish_output=true
|
||||
publish_output_progress=true
|
||||
publish_output_no_suggestions=true
|
||||
verbosity_level=0 # 0,1,2
|
||||
use_extra_bad_extensions=false
|
||||
# Configurations
|
||||
use_wiki_settings_file=true
|
||||
use_repo_settings_file=true
|
||||
use_global_settings_file=true
|
||||
disable_auto_feedback = false
|
||||
ai_timeout=120 # 2 minutes
|
||||
skip_keys = []
|
||||
custom_reasoning_model = true # when true, disables system messages and temperature controls for models that don't support chat-style inputs
|
||||
# token limits
|
||||
max_description_tokens = 500
|
||||
max_commits_tokens = 500
|
||||
max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
|
||||
custom_model_max_tokens=-1 # for models not in the default list
|
||||
# patch extension logic
|
||||
patch_extension_skip_types =[".md",".txt"]
|
||||
allow_dynamic_context=true
|
||||
max_extra_lines_before_dynamic_context = 8 # will try to include up to this many extra lines before the hunk in the patch, until we reach an enclosing function or class
|
||||
patch_extra_lines_before = 3 # Number of extra lines (+3 default ones) to include before each hunk in the patch
|
||||
patch_extra_lines_after = 1 # Number of extra lines (+3 default ones) to include after each hunk in the patch
|
||||
secret_provider=""
|
||||
cli_mode=false
|
||||
ai_disclaimer_title="" # Pro feature, title for a collapsible disclaimer to AI outputs
|
||||
ai_disclaimer="" # Pro feature, full text for the AI disclaimer
|
||||
output_relevant_configurations=false
|
||||
large_patch_policy = "clip" # "clip", "skip"
|
||||
duplicate_prompt_examples = false
|
||||
# seed
|
||||
seed=-1 # set positive value to fix the seed (and ensure temperature=0)
|
||||
temperature=0.2
|
||||
# ignore logic
|
||||
ignore_pr_title = ["^\\[Auto\\]", "^Auto"] # a list of regular expressions to match against the PR title to ignore the PR agent
|
||||
ignore_pr_target_branches = [] # a list of regular expressions of target branches to ignore from PR agent when a PR is created
|
||||
ignore_pr_source_branches = [] # a list of regular expressions of source branches to ignore from PR agent when a PR is created
|
||||
ignore_pr_labels = [] # labels to ignore from PR agent when a PR is created
|
||||
ignore_pr_authors = [] # authors to ignore from PR agent when a PR is created
|
||||
#
|
||||
is_auto_command = false # will be auto-set to true if the command is triggered by an automation
|
||||
enable_ai_metadata = false # will enable adding ai metadata
|
||||
# auto approval 💎
|
||||
enable_auto_approval=false # Set to true to enable auto-approval of PRs under certain conditions
|
||||
auto_approve_for_low_review_effort=-1 # -1 to disable, [1-5] to set the threshold for auto-approval
|
||||
auto_approve_for_no_suggestions=false # If true, the PR will be auto-approved if there are no suggestions
|
||||
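As a concrete instance of the ignore logic driven by the `[config]` keys above: with `ignore_pr_title = ["^\\[Auto\\]", "^Auto"]`, a PR is skipped when any pattern matches its title. The titles below are illustrative.

```python
import re

ignore_pr_title = [r"^\[Auto\]", r"^Auto"]   # same patterns as configured above

for title in ["[Auto] bump dependencies", "Auto-format code", "Fix login redirect"]:
    skipped = any(re.search(regex, title) for regex in ignore_pr_title)
    print(title, "->", "ignored" if skipped else "processed")
# [Auto] bump dependencies -> ignored
# Auto-format code -> ignored
# Fix login redirect -> processed
```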
|
||||
|
||||
[pr_reviewer] # /review #
|
||||
# enable/disable features
|
||||
require_score_review=false
|
||||
require_tests_review=true
|
||||
require_estimate_effort_to_review=true
|
||||
require_can_be_split_review=false
|
||||
require_security_review=true
|
||||
require_ticket_analysis_review=true
|
||||
# general options
|
||||
persistent_comment=true
|
||||
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号"
|
||||
final_update_message = true
|
||||
# review labels
|
||||
enable_review_labels_security=true
|
||||
enable_review_labels_effort=true
|
||||
# specific configurations for incremental review (/review -i)
|
||||
require_all_thresholds_for_incremental_review=false
|
||||
minimal_commits_for_incremental_review=0
|
||||
minimal_minutes_for_incremental_review=0
|
||||
enable_intro_text=true
|
||||
enable_help_text=false # Determines whether to include help text in the PR review. Enabled by default.
|
||||
|
||||
[pr_description] # /describe #
|
||||
publish_labels=false
|
||||
add_original_user_description=true
|
||||
generate_ai_title=false
|
||||
use_bullet_points=true
|
||||
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号"
|
||||
enable_pr_type=true
|
||||
final_update_message = true
|
||||
enable_help_text=false
|
||||
enable_help_comment=true
|
||||
# describe as comment
|
||||
publish_description_as_comment=false
|
||||
publish_description_as_comment_persistent=true
|
||||
## changes walkthrough section
|
||||
enable_semantic_files_types=true
|
||||
collapsible_file_list='adaptive' # true, false, 'adaptive'
|
||||
collapsible_file_list_threshold=8
|
||||
inline_file_summary=false # false, true, 'table'
|
||||
# markers
|
||||
use_description_markers=false
|
||||
include_generated_by_header=true
|
||||
# large pr mode 💎
|
||||
enable_large_pr_handling=true
|
||||
max_ai_calls=4
|
||||
async_ai_calls=true
|
||||
#custom_labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Enhancement', 'Documentation', 'Other']
|
||||
|
||||
[pr_questions] # /ask #
|
||||
enable_help_text=false
|
||||
|
||||
|
||||
[pr_code_suggestions] # /improve #
|
||||
max_context_tokens=16000
|
||||
#
|
||||
commitable_code_suggestions = false
|
||||
dual_publishing_score_threshold=-1 # -1 to disable, [0-10] to set the threshold (>=) for publishing a code suggestion both in a table and as commitable
|
||||
focus_only_on_problems=true
|
||||
#
|
||||
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号"
|
||||
enable_help_text=false
|
||||
enable_chat_text=false
|
||||
enable_intro_text=true
|
||||
persistent_comment=true
|
||||
max_history_len=4
|
||||
# enable to apply suggestion 💎
|
||||
apply_suggestions_checkbox=true
|
||||
# suggestions scoring
|
||||
suggestions_score_threshold=0 # [0-10]| recommend not to set this value above 8, since above it may clip highly relevant suggestions
|
||||
new_score_mechanism=true
|
||||
new_score_mechanism_th_high=9
|
||||
new_score_mechanism_th_medium=7
|
||||
# params for '/improve --extended' mode
|
||||
auto_extended_mode=true
|
||||
num_code_suggestions_per_chunk=4
|
||||
max_number_of_calls = 3
|
||||
parallel_calls = true
|
||||
|
||||
final_clip_factor = 0.8
|
||||
# self-review checkbox
|
||||
demand_code_suggestions_self_review=false # add a checkbox for the author to self-review the code suggestions
|
||||
code_suggestions_self_review_text= "**作者自我审查**:我已审核PR代码建议,并处理了相关内容."
|
||||
approve_pr_on_self_review=false # Pro feature. if true, the PR will be auto-approved after the author clicks on the self-review checkbox
|
||||
fold_suggestions_on_self_review=true # Pro feature. if true, the code suggestions will be folded after the author clicks on the self-review checkbox
|
||||
# Suggestion impact 💎
|
||||
publish_post_process_suggestion_impact=true
|
||||
wiki_page_accepted_suggestions=true
|
||||
allow_thumbs_up_down=false
|
||||
|
||||
[pr_custom_prompt] # /custom_prompt #
|
||||
prompt = """\
|
||||
The code suggestions should focus only on the following:
|
||||
- ...
|
||||
- ...
|
||||
...
|
||||
"""
|
||||
suggestions_score_threshold=0
|
||||
num_code_suggestions_per_chunk=4
|
||||
self_reflect_on_custom_suggestions=true
|
||||
enable_help_text=false
|
||||
|
||||
|
||||
[pr_add_docs] # /add_docs #
|
||||
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号"
|
||||
docs_style = "Sphinx" # "Google Style with Args, Returns, Attributes...etc", "Numpy Style", "Sphinx Style", "PEP257", "reStructuredText"
|
||||
file = "" # in case there are several components with the same name, you can specify the relevant file
|
||||
class_name = "" # in case there are several methods with the same name in the same file, you can specify the relevant class name
|
||||
|
||||
[pr_update_changelog] # /update_changelog #
|
||||
push_changelog_changes=false
|
||||
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号"
|
||||
add_pr_link=true
|
||||
|
||||
[pr_analyze] # /analyze #
|
||||
enable_help_text=true
|
||||
|
||||
[pr_test] # /test #
|
||||
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号"
|
||||
testing_framework = "" # specify the testing framework you want to use
|
||||
num_tests=3 # number of tests to generate. max 5.
|
||||
avoid_mocks=true # if true, the generated tests will prefer to use real objects instead of mocks
|
||||
file = "" # in case there are several components with the same name, you can specify the relevant file
|
||||
class_name = "" # in case there are several methods with the same name in the same file, you can specify the relevant class name
|
||||
enable_help_text=false
|
||||
|
||||
[pr_improve_component] # /improve_component #
|
||||
num_code_suggestions=4
|
||||
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号"
|
||||
file = "" # in case there are several components with the same name, you can specify the relevant file
|
||||
class_name = "" # in case there are several methods with the same name in the same file, you can specify the relevant class name
|
||||
|
||||
[checks] # /checks (pro feature) #
|
||||
enable_auto_checks_feedback=true
|
||||
excluded_checks_list=["lint"] # list of checks to exclude, for example: ["check1", "check2"]
|
||||
persistent_comment=true
|
||||
enable_help_text=true
|
||||
final_update_message = false
|
||||
|
||||
[pr_help] # /help #
|
||||
force_local_db=false
|
||||
num_retrieved_snippets=5
|
||||
|
||||
[pr_config] # /config #
|
||||
|
||||
[github]
|
||||
# The type of deployment to create. Valid values are 'app' or 'user'.
|
||||
deployment_type = "user"
|
||||
ratelimit_retries = 5
|
||||
base_url = "https://api.github.com"
|
||||
publish_inline_comments_fallback_with_verification = true
|
||||
try_fix_invalid_inline_comments = true
|
||||
app_name = "pr-agent"
|
||||
ignore_bot_pr = true
|
||||
|
||||
[github_action_config]
|
||||
# auto_review = true # set as env var in .github/workflows/pr-agent.yaml
|
||||
# auto_describe = true # set as env var in .github/workflows/pr-agent.yaml
|
||||
# auto_improve = true # set as env var in .github/workflows/pr-agent.yaml
|
||||
# pr_actions = ['opened', 'reopened', 'ready_for_review', 'review_requested']
|
||||
|
||||
[github_app]
|
||||
# these toggles allows running the github app from custom deployments
|
||||
bot_user = "github-actions[bot]"
|
||||
override_deployment_type = true
|
||||
# settings for "pull_request" event
|
||||
handle_pr_actions = ['opened', 'reopened', 'ready_for_review']
|
||||
pr_commands = [
|
||||
"/describe --pr_description.final_update_message=false",
|
||||
"/review",
|
||||
"/improve",
|
||||
]
|
||||
# settings for "pull_request" event with "synchronize" action - used to detect and handle push triggers for new commits
|
||||
handle_push_trigger = false
|
||||
push_trigger_ignore_bot_commits = true
|
||||
push_trigger_ignore_merge_commits = true
|
||||
push_trigger_wait_for_initial_review = true
|
||||
push_trigger_pending_tasks_backlog = true
|
||||
push_trigger_pending_tasks_ttl = 300
|
||||
push_commands = [
|
||||
"/describe",
|
||||
"/review",
|
||||
]
|
||||
|
||||
[gitlab]
|
||||
url = "http://192.168.1.91:4010"
|
||||
pr_commands = [
|
||||
"/describe --pr_description.final_update_message=false",
|
||||
"/review",
|
||||
"/improve",
|
||||
]
|
||||
handle_push_trigger = true
|
||||
push_commands = [
|
||||
"/describe",
|
||||
"/review",
|
||||
]
|
||||
|
||||
[bitbucket_app]
|
||||
pr_commands = [
|
||||
"/describe --pr_description.final_update_message=false",
|
||||
"/review",
|
||||
"/improve --pr_code_suggestions.commitable_code_suggestions=true",
|
||||
]
|
||||
avoid_full_files = false
|
||||
|
||||
[local]
|
||||
# LocalGitProvider settings - uncomment to use paths other than default
|
||||
# description_path= "path/to/description.md"
|
||||
# review_path= "path/to/review.md"
|
||||
|
||||
[gerrit]
|
||||
# endpoint to the gerrit service
|
||||
# url = "ssh://gerrit.example.com:29418"
|
||||
# user for gerrit authentication
|
||||
# user = "ai-reviewer"
|
||||
# patch server where patches will be saved
|
||||
# patch_server_endpoint = "http://127.0.0.1:5000/patch"
|
||||
# token to authenticate in the patch server
|
||||
# patch_server_token = ""
|
||||
|
||||
[bitbucket_server]
|
||||
# URL to the BitBucket Server instance
|
||||
# url = "https://git.bitbucket.com"
|
||||
url = ""
|
||||
pr_commands = [
|
||||
"/describe --pr_description.final_update_message=false",
|
||||
"/review",
|
||||
"/improve --pr_code_suggestions.commitable_code_suggestions=true",
|
||||
]
|
||||
|
||||
[litellm]
|
||||
# use_client = false
|
||||
# drop_params = false
|
||||
enable_callbacks = false
|
||||
success_callback = []
|
||||
failure_callback = []
|
||||
service_callback = []
|
||||
|
||||
[pr_similar_issue]
|
||||
skip_comments = false
|
||||
force_update_dataset = false
|
||||
max_issues_to_scan = 500
|
||||
vectordb = "pinecone"
|
||||
|
||||
[pr_find_similar_component]
|
||||
class_name = ""
|
||||
file = ""
|
||||
search_from_org = false
|
||||
allow_fallback_less_words = true
|
||||
number_of_keywords = 5
|
||||
number_of_results = 5
|
||||
|
||||
[pinecone]
|
||||
# fill and place in .secrets.toml
|
||||
#api_key = ...
|
||||
# environment = "gcp-starter"
|
||||
|
||||
[lancedb]
|
||||
uri = "./lancedb"
|
||||
|
||||
[best_practices]
|
||||
content = ""
|
||||
organization_name = ""
|
||||
max_lines_allowed = 800
|
||||
enable_global_best_practices = false
|
||||
|
||||
[auto_best_practices]
|
||||
enable_auto_best_practices = true # public - general flag to disable all auto best practices usage
|
||||
utilize_auto_best_practices = true # public - disable usage of auto best practices in the 'improve' tool
|
||||
extra_instructions = "回答必须使用简体中文,并且必须使用英文标点符号" # public - extra instructions to the auto best practices generation prompt
|
||||
content = ""
|
||||
max_patterns = 5 # max number of patterns to be detected
|
||||
[openai]
|
||||
api_base = "http://110.40.24.85:3000/v1"
|
||||
[llm]
|
||||
model = "o3-mini"
|
||||
force_config = true
|
||||
16
apps/utils/pr_agent/settings/custom_labels.toml
Normal file
16
apps/utils/pr_agent/settings/custom_labels.toml
Normal file
@ -0,0 +1,16 @@
[config]
enable_custom_labels=false

## template for custom labels
#[custom_labels."Bug fix"]
#description = """Fixes a bug in the code"""
#[custom_labels."Tests"]
#description = """Adds or modifies tests"""
#[custom_labels."Bug fix with tests"]
#description = """Fixes a bug in the code and adds or modifies tests"""
#[custom_labels."Enhancement"]
#description = """Adds new features or modifies existing ones"""
#[custom_labels."Documentation"]
#description = """Adds or modifies documentation"""
#[custom_labels."Other"]
#description = """Other changes that do not fit in any of the above categories"""
12
apps/utils/pr_agent/settings/ignore.toml
Normal file
12
apps/utils/pr_agent/settings/ignore.toml
Normal file
@ -0,0 +1,12 @@
[ignore]

glob = [
    # Ignore files and directories matching these glob patterns.
    # See https://docs.python.org/3/library/glob.html
    'vendor/**',
]
regex = [
    # Ignore files and directories matching these regex patterns.
    # See https://learnbyexample.github.io/python-regex-cheatsheet/
    # for example: regex = ['.*\.toml$']
]
440
apps/utils/pr_agent/settings/language_extensions.toml
Normal file
440
apps/utils/pr_agent/settings/language_extensions.toml
Normal file
@ -0,0 +1,440 @@
|
||||
[bad_extensions]
|
||||
default = [
|
||||
'app',
|
||||
'bin',
|
||||
'bmp',
|
||||
'bz2',
|
||||
'class',
|
||||
'csv',
|
||||
'dat',
|
||||
'db',
|
||||
'dll',
|
||||
'dylib',
|
||||
'egg',
|
||||
'eot',
|
||||
'exe',
|
||||
'gif',
|
||||
'gitignore',
|
||||
'glif',
|
||||
'gradle',
|
||||
'gz',
|
||||
'ico',
|
||||
'jar',
|
||||
'jpeg',
|
||||
'jpg',
|
||||
'lo',
|
||||
'lock',
|
||||
'log',
|
||||
'mp3',
|
||||
'mp4',
|
||||
'nar',
|
||||
'o',
|
||||
'ogg',
|
||||
'otf',
|
||||
'p',
|
||||
'pdf',
|
||||
'png',
|
||||
'pickle',
|
||||
'pkl',
|
||||
'pyc',
|
||||
'pyd',
|
||||
'pyo',
|
||||
'rkt',
|
||||
'so',
|
||||
'ss',
|
||||
'svg',
|
||||
'tar',
|
||||
'tgz',
|
||||
'tsv',
|
||||
'ttf',
|
||||
'war',
|
||||
'webm',
|
||||
'woff',
|
||||
'woff2',
|
||||
'xz',
|
||||
'zip',
|
||||
'zst',
|
||||
'snap',
|
||||
'lockb'
|
||||
]
|
||||
extra = [
|
||||
'md',
|
||||
'txt'
|
||||
]
|
||||
|
||||
[language_extension_map_org]
|
||||
"1C Enterprise" = ["*.bsl", ]
|
||||
ABAP = [".abap", ]
|
||||
"AGS Script" = [".ash", ]
|
||||
AMPL = [".ampl", ]
|
||||
ANTLR = [".g4", ]
|
||||
"API Blueprint" = [".apib", ]
|
||||
APL = [".apl", ".dyalog", ]
|
||||
ASP = [".asp", ".asax", ".ascx", ".ashx", ".asmx", ".aspx", ".axd", ]
|
||||
ATS = [".dats", ".hats", ".sats", ]
|
||||
ActionScript = [".as", ]
|
||||
Ada = [".adb", ".ada", ".ads", ]
|
||||
Agda = [".agda", ]
|
||||
Alloy = [".als", ]
|
||||
ApacheConf = [".apacheconf", ".vhost", ]
|
||||
AppleScript = [".applescript", ".scpt", ]
|
||||
Arc = [".arc", ]
|
||||
Arduino = [".ino", ]
|
||||
AsciiDoc = [".asciidoc", ".adoc", ]
|
||||
AspectJ = [".aj", ]
|
||||
Assembly = [".asm", ".a51", ".nasm", ]
|
||||
Augeas = [".aug", ]
|
||||
AutoHotkey = [".ahk", ".ahkl", ]
|
||||
AutoIt = [".au3", ]
|
||||
Awk = [".awk", ".auk", ".gawk", ".mawk", ".nawk", ]
|
||||
Batchfile = [".bat", ".cmd", ]
|
||||
Befunge = [".befunge", ]
|
||||
Bison = [".bison", ]
|
||||
BitBake = [".bb", ]
|
||||
BlitzBasic = [".decls", ]
|
||||
BlitzMax = [".bmx", ]
|
||||
Bluespec = [".bsv", ]
|
||||
Boo = [".boo", ]
|
||||
Brainfuck = [".bf", ]
|
||||
Brightscript = [".brs", ]
|
||||
Bro = [".bro", ]
|
||||
C = [".c", ".cats", ".h", ".idc", ".w", ]
|
||||
"C#" = [".cs", ".cake", ".cshtml", ".csx", ]
|
||||
"C++" = [".cpp", ".c++", ".cc", ".cp", ".cxx", ".h++", ".hh", ".hpp", ".hxx", ".inl", ".ipp", ".tcc", ".tpp", ".C", ".H", ]
|
||||
C-ObjDump = [".c-objdump", ]
|
||||
"C2hs Haskell" = [".chs", ]
|
||||
CLIPS = [".clp", ]
|
||||
CMake = [".cmake", ".cmake.in", ]
|
||||
COBOL = [".cob", ".cbl", ".ccp", ".cobol", ".cpy", ]
|
||||
CSS = [".css", ]
|
||||
CSV = [".csv", ]
|
||||
"Cap'n Proto" = [".capnp", ]
|
||||
CartoCSS = [".mss", ]
|
||||
Ceylon = [".ceylon", ]
|
||||
Chapel = [".chpl", ]
|
||||
ChucK = [".ck", ]
|
||||
Cirru = [".cirru", ]
|
||||
Clarion = [".clw", ]
|
||||
Clean = [".icl", ".dcl", ]
|
||||
Click = [".click", ]
|
||||
Clojure = [".clj", ".boot", ".cl2", ".cljc", ".cljs", ".cljs.hl", ".cljscm", ".cljx", ".hic", ]
|
||||
CoffeeScript = [".coffee", "._coffee", ".cjsx", ".cson", ".iced", ]
|
||||
ColdFusion = [".cfm", ".cfml", ]
|
||||
"ColdFusion CFC" = [".cfc", ]
|
||||
"Common Lisp" = [".lisp", ".asd", ".lsp", ".ny", ".podsl", ".sexp", ]
|
||||
"Component Pascal" = [".cps", ]
|
||||
Coq = [".coq", ]
|
||||
Cpp-ObjDump = [".cppobjdump", ".c++-objdump", ".c++objdump", ".cpp-objdump", ".cxx-objdump", ]
|
||||
Creole = [".creole", ]
|
||||
Crystal = [".cr", ]
|
||||
Csound = [".csd", ]
|
||||
Cucumber = [".feature", ]
|
||||
Cuda = [".cu", ".cuh", ]
|
||||
Cycript = [".cy", ]
|
||||
Cython = [".pyx", ".pxd", ".pxi", ]
|
||||
D = [".di", ]
|
||||
D-ObjDump = [".d-objdump", ]
|
||||
"DIGITAL Command Language" = [".com", ]
|
||||
DM = [".dm", ]
|
||||
"DNS Zone" = [".zone", ".arpa", ]
|
||||
"Darcs Patch" = [".darcspatch", ".dpatch", ]
|
||||
Dart = [".dart", ]
|
||||
Diff = [".diff", ".patch", ]
|
||||
Dockerfile = [".dockerfile", "Dockerfile", ]
|
||||
Dogescript = [".djs", ]
|
||||
Dylan = [".dylan", ".dyl", ".intr", ".lid", ]
|
||||
E = [".E", ]
|
||||
ECL = [".ecl", ".eclxml", ]
|
||||
Eagle = [".sch", ".brd", ]
|
||||
"Ecere Projects" = [".epj", ]
|
||||
Eiffel = [".e", ]
|
||||
Elixir = [".ex", ".exs", ]
|
||||
Elm = [".elm", ]
|
||||
"Emacs Lisp" = [".el", ".emacs", ".emacs.desktop", ]
|
||||
EmberScript = [".em", ".emberscript", ]
|
||||
Erlang = [".erl", ".escript", ".hrl", ".xrl", ".yrl", ]
|
||||
"F#" = [".fs", ".fsi", ".fsx", ]
|
||||
FLUX = [".flux", ]
|
||||
FORTRAN = [".f90", ".f", ".f03", ".f08", ".f77", ".f95", ".for", ".fpp", ]
|
||||
Factor = [".factor", ]
|
||||
Fancy = [".fy", ".fancypack", ]
|
||||
Fantom = [".fan", ]
|
||||
Formatted = [".eam.fs", ]
|
||||
Forth = [".fth", ".4th", ".forth", ".frt", ]
|
||||
FreeMarker = [".ftl", ]
|
||||
G-code = [".g", ".gco", ".gcode", ]
|
||||
GAMS = [".gms", ]
|
||||
GAP = [".gap", ".gi", ]
|
||||
GAS = [".s", ]
|
||||
GDScript = [".gd", ]
|
||||
GLSL = [".glsl", ".fp", ".frag", ".frg", ".fsh", ".fshader", ".geo", ".geom", ".glslv", ".gshader", ".shader", ".vert", ".vrx", ".vsh", ".vshader", ]
|
||||
Genshi = [".kid", ]
|
||||
"Gentoo Ebuild" = [".ebuild", ]
|
||||
"Gentoo Eclass" = [".eclass", ]
|
||||
"Gettext Catalog" = [".po", ".pot", ]
|
||||
Glyph = [".glf", ]
|
||||
Gnuplot = [".gp", ".gnu", ".gnuplot", ".plot", ".plt", ]
|
||||
Go = [".go", ]
|
||||
Golo = [".golo", ]
|
||||
Gosu = [".gst", ".gsx", ".vark", ]
|
||||
Grace = [".grace", ]
|
||||
Gradle = [".gradle", ]
|
||||
"Grammatical Framework" = [".gf", ]
|
||||
GraphQL = [".graphql", ]
|
||||
"Graphviz (DOT)" = [".dot", ".gv", ]
|
||||
Groff = [".man", ".1", ".1in", ".1m", ".1x", ".2", ".3", ".3in", ".3m", ".3qt", ".3x", ".4", ".5", ".6", ".7", ".8", ".9", ".me", ".rno", ".roff", ]
|
||||
Groovy = [".groovy", ".grt", ".gtpl", ".gvy", ]
|
||||
"Groovy Server Pages" = [".gsp", ]
|
||||
HCL = [".hcl", ".tf", ]
|
||||
HLSL = [".hlsl", ".fxh", ".hlsli", ]
|
||||
HTML = [".html", ".htm", ".html.hl", ".xht", ".xhtml", ]
|
||||
"HTML+Django" = [".mustache", ".jinja", ]
|
||||
"HTML+EEX" = [".eex", ]
|
||||
"HTML+ERB" = [".erb", ".erb.deface", ]
|
||||
"HTML+PHP" = [".phtml", ]
|
||||
HTTP = [".http", ]
|
||||
Haml = [".haml", ".haml.deface", ]
|
||||
Handlebars = [".handlebars", ".hbs", ]
|
||||
Harbour = [".hb", ]
|
||||
Haskell = [".hs", ".hsc", ]
|
||||
Haxe = [".hx", ".hxsl", ]
|
||||
Hy = [".hy", ]
|
||||
IDL = [".dlm", ]
|
||||
"IGOR Pro" = [".ipf", ]
|
||||
INI = [".ini", ".cfg", ".prefs", ".properties", ]
|
||||
"IRC log" = [".irclog", ".weechatlog", ]
|
||||
Idris = [".idr", ".lidr", ]
|
||||
"Inform 7" = [".ni", ".i7x", ]
|
||||
"Inno Setup" = [".iss", ]
|
||||
Io = [".io", ]
|
||||
Ioke = [".ik", ]
|
||||
Isabelle = [".thy", ]
|
||||
J = [".ijs", ]
|
||||
JFlex = [".flex", ".jflex", ]
|
||||
JSON = [".json", ".geojson", ".lock", ".topojson", ]
|
||||
JSON5 = [".json5", ]
|
||||
JSONLD = [".jsonld", ]
|
||||
JSONiq = [".jq", ]
|
||||
JSX = [".jsx", ]
|
||||
Jade = [".jade", ]
|
||||
Jasmin = [".j", ]
|
||||
Java = [".java", ]
|
||||
"Java Server Pages" = [".jsp", ]
|
||||
JavaScript = [".js", "._js", ".bones", ".es6", ".jake", ".jsb", ".jscad", ".jsfl", ".jsm", ".jss", ".njs", ".pac", ".sjs", ".ssjs", ".xsjs", ".xsjslib", ]
|
||||
Julia = [".jl", ]
|
||||
"Jupyter Notebook" = [".ipynb", ]
|
||||
KRL = [".krl", ]
|
||||
KiCad = [".kicad_pcb", ]
|
||||
Kit = [".kit", ]
|
||||
Kotlin = [".kt", ".ktm", ".kts", ]
|
||||
LFE = [".lfe", ]
|
||||
LLVM = [".ll", ]
|
||||
LOLCODE = [".lol", ]
|
||||
LSL = [".lsl", ".lslp", ]
|
||||
LabVIEW = [".lvproj", ]
|
||||
Lasso = [".lasso", ".las", ".lasso8", ".lasso9", ".ldml", ]
|
||||
Latte = [".latte", ]
|
||||
Lean = [".lean", ".hlean", ]
|
||||
Less = [".less", ]
|
||||
Lex = [".lex", ]
|
||||
LilyPond = [".ly", ".ily", ]
|
||||
"Linker Script" = [".ld", ".lds", ]
|
||||
Liquid = [".liquid", ]
|
||||
"Literate Agda" = [".lagda", ]
|
||||
"Literate CoffeeScript" = [".litcoffee", ]
|
||||
"Literate Haskell" = [".lhs", ]
|
||||
LiveScript = [".ls", "._ls", ]
|
||||
Logos = [".xm", ".x", ".xi", ]
|
||||
Logtalk = [".lgt", ".logtalk", ]
|
||||
LookML = [".lookml", ]
|
||||
Lua = [".lua", ".nse", ".pd_lua", ".rbxs", ".wlua", ]
|
||||
M = [".mumps", ]
|
||||
M4 = [".m4", ]
|
||||
MAXScript = [".mcr", ]
|
||||
MTML = [".mtml", ]
|
||||
MUF = [".muf", ]
|
||||
Makefile = [".mak", ".mk", ".mkfile", "Makefile", ]
|
||||
Mako = [".mako", ".mao", ]
|
||||
Maple = [".mpl", ]
|
||||
Markdown = [".md", ".markdown", ".mkd", ".mkdn", ".mkdown", ".ron", ]
|
||||
Mask = [".mask", ]
|
||||
Mathematica = [".mathematica", ".cdf", ".ma", ".mt", ".nb", ".nbp", ".wl", ".wlt", ]
|
||||
Matlab = [".matlab", ]
|
||||
Max = [".maxpat", ".maxhelp", ".maxproj", ".mxt", ".pat", ]
|
||||
MediaWiki = [".mediawiki", ".wiki", ]
|
||||
Metal = [".metal", ]
|
||||
MiniD = [".minid", ]
|
||||
Mirah = [".druby", ".duby", ".mir", ".mirah", ]
|
||||
Modelica = [".mo", ]
|
||||
"Module Management System" = [".mms", ".mmk", ]
|
||||
Monkey = [".monkey", ]
|
||||
MoonScript = [".moon", ]
|
||||
Myghty = [".myt", ]
|
||||
NSIS = [".nsi", ".nsh", ]
|
||||
NetLinx = [".axs", ".axi", ]
|
||||
"NetLinx+ERB" = [".axs.erb", ".axi.erb", ]
|
||||
NetLogo = [".nlogo", ]
|
||||
Nginx = [".nginxconf", ]
|
||||
Nimrod = [".nim", ".nimrod", ]
|
||||
Ninja = [".ninja", ]
|
||||
Nit = [".nit", ]
|
||||
Nix = [".nix", ]
|
||||
Nu = [".nu", ]
|
||||
NumPy = [".numpy", ".numpyw", ".numsc", ]
|
||||
OCaml = [".ml", ".eliom", ".eliomi", ".ml4", ".mli", ".mll", ".mly", ]
|
||||
ObjDump = [".objdump", ]
|
||||
"Objective-C++" = [".mm", ]
|
||||
Objective-J = [".sj", ]
|
||||
Octave = [".oct", ]
|
||||
Omgrofl = [".omgrofl", ]
|
||||
Opa = [".opa", ]
|
||||
Opal = [".opal", ]
|
||||
OpenCL = [".cl", ".opencl", ]
|
||||
"OpenEdge ABL" = [".p", ]
|
||||
OpenSCAD = [".scad", ]
|
||||
Org = [".org", ]
|
||||
Ox = [".ox", ".oxh", ".oxo", ]
|
||||
Oxygene = [".oxygene", ]
|
||||
Oz = [".oz", ]
|
||||
PAWN = [".pwn", ]
|
||||
PHP = [".php", ".aw", ".ctp", ".php3", ".php4", ".php5", ".phps", ".phpt", ]
|
||||
"POV-Ray SDL" = [".pov", ]
|
||||
Pan = [".pan", ]
|
||||
Papyrus = [".psc", ]
|
||||
Parrot = [".parrot", ]
|
||||
"Parrot Assembly" = [".pasm", ]
|
||||
"Parrot Internal Representation" = [".pir", ]
|
||||
Pascal = [".pas", ".dfm", ".dpr", ".lpr", ]
|
||||
Perl = [".pl", ".al", ".perl", ".ph", ".plx", ".pm", ".psgi", ".t", ]
|
||||
Perl6 = [".6pl", ".6pm", ".nqp", ".p6", ".p6l", ".p6m", ".pl6", ".pm6", ]
|
||||
Pickle = [".pkl", ]
|
||||
PigLatin = [".pig", ]
|
||||
Pike = [".pike", ".pmod", ]
|
||||
Pod = [".pod", ]
|
||||
PogoScript = [".pogo", ]
|
||||
Pony = [".pony", ]
|
||||
PostScript = [".ps", ".eps", ]
|
||||
PowerShell = [".ps1", ".psd1", ".psm1", ]
|
||||
Processing = [".pde", ]
|
||||
Prolog = [".prolog", ".yap", ]
|
||||
"Propeller Spin" = [".spin", ]
|
||||
"Protocol Buffer" = [".proto", ]
|
||||
"Public Key" = [".pub", ]
|
||||
"Pure Data" = [".pd", ]
|
||||
PureBasic = [".pb", ".pbi", ]
|
||||
PureScript = [".purs", ]
|
||||
Python = [".py", ".bzl", ".gyp", ".lmi", ".pyde", ".pyp", ".pyt", ".pyw", ".tac", ".wsgi", ".xpy", ]
|
||||
"Python traceback" = [".pytb", ]
|
||||
QML = [".qml", ".qbs", ]
|
||||
QMake = [".pri", ]
|
||||
R = [".r", ".rd", ".rsx", ]
|
||||
RAML = [".raml", ]
|
||||
RDoc = [".rdoc", ]
|
||||
REALbasic = [".rbbas", ".rbfrm", ".rbmnu", ".rbres", ".rbtbar", ".rbuistate", ]
|
||||
RHTML = [".rhtml", ]
|
||||
RMarkdown = [".rmd", ]
|
||||
Racket = [".rkt", ".rktd", ".rktl", ".scrbl", ]
|
||||
"Ragel in Ruby Host" = [".rl", ]
|
||||
"Raw token data" = [".raw", ]
|
||||
Rebol = [".reb", ".r2", ".r3", ".rebol", ]
|
||||
Red = [".red", ".reds", ]
|
||||
Redcode = [".cw", ]
|
||||
"Ren'Py" = [".rpy", ]
|
||||
RenderScript = [".rsh", ]
|
||||
RobotFramework = [".robot", ]
|
||||
Rouge = [".rg", ]
|
||||
Ruby = [".rb", ".builder", ".gemspec", ".god", ".irbrc", ".jbuilder", ".mspec", ".podspec", ".rabl", ".rake", ".rbuild", ".rbw", ".rbx", ".ru", ".ruby", ".thor", ".watchr", ]
|
||||
Rust = [".rs", ".rs.in", ]
|
||||
SAS = [".sas", ]
|
||||
SCSS = [".scss", ]
|
||||
SMT = [".smt2", ".smt", ]
|
||||
SPARQL = [".sparql", ".rq", ]
|
||||
SQF = [".sqf", ".hqf", ]
|
||||
SQL = [".pls", ".pck", ".pkb", ".pks", ".plb", ".plsql", ".sql", ".cql", ".ddl", ".prc", ".tab", ".udf", ".viw", ".db2", ]
|
||||
STON = [".ston", ]
|
||||
SVG = [".svg", ]
|
||||
Sage = [".sage", ".sagews", ]
|
||||
SaltStack = [".sls", ]
|
||||
Sass = [".sass", ]
|
||||
Scala = [".scala", ".sbt", ]
|
||||
Scaml = [".scaml", ]
|
||||
Scheme = [".scm", ".sld", ".sps", ".ss", ]
|
||||
Scilab = [".sci", ".sce", ]
|
||||
Self = [".self", ]
|
||||
Shell = [".sh", ".bash", ".bats", ".command", ".ksh", ".sh.in", ".tmux", ".tool", ".zsh", ]
|
||||
ShellSession = [".sh-session", ]
|
||||
Shen = [".shen", ]
|
||||
Slash = [".sl", ]
|
||||
Slim = [".slim", ]
|
||||
Smali = [".smali", ]
|
||||
Smalltalk = [".st", ]
|
||||
Smarty = [".tpl", ]
|
||||
Solidity = [".sol", ]
|
||||
SourcePawn = [".sp", ".sma", ]
|
||||
Squirrel = [".nut", ]
|
||||
Stan = [".stan", ]
|
||||
"Standard ML" = [".ML", ".fun", ".sig", ".sml", ]
|
||||
Stata = [".do", ".ado", ".doh", ".ihlp", ".mata", ".matah", ".sthlp", ]
|
||||
Stylus = [".styl", ]
|
||||
SuperCollider = [".scd", ]
|
||||
Swift = [".swift", ]
|
||||
SystemVerilog = [".sv", ".svh", ".vh", ]
|
||||
TOML = [".toml", ]
|
||||
TXL = [".txl", ]
|
||||
Tcl = [".tcl", ".adp", ".tm", ]
|
||||
Tcsh = [".tcsh", ".csh", ]
|
||||
TeX = [".tex", ".aux", ".bbx", ".bib", ".cbx", ".dtx", ".ins", ".lbx", ".ltx", ".mkii", ".mkiv", ".mkvi", ".sty", ".toc", ]
|
||||
Tea = [".tea", ]
|
||||
Text = [".txt", ".no", ]
|
||||
Textile = [".textile", ]
|
||||
Thrift = [".thrift", ]
|
||||
Turing = [".tu", ]
|
||||
Turtle = [".ttl", ]
|
||||
Twig = [".twig", ]
|
||||
TypeScript = [".ts", ".tsx", ]
|
||||
"Unified Parallel C" = [".upc", ]
|
||||
"Unity3D Asset" = [".anim", ".asset", ".mat", ".meta", ".prefab", ".unity", ]
|
||||
Uno = [".uno", ]
|
||||
UnrealScript = [".uc", ]
|
||||
UrWeb = [".ur", ".urs", ]
|
||||
VCL = [".vcl", ]
|
||||
VHDL = [".vhdl", ".vhd", ".vhf", ".vhi", ".vho", ".vhs", ".vht", ".vhw", ]
|
||||
Vala = [".vala", ".vapi", ]
|
||||
Verilog = [".veo", ]
|
||||
VimL = [".vim", ]
|
||||
"Visual Basic" = [".vb", ".bas", ".frm", ".frx", ".vba", ".vbhtml", ".vbs", ]
|
||||
Volt = [".volt", ]
|
||||
Vue = [".vue", ]
|
||||
"Web Ontology Language" = [".owl", ]
|
||||
WebAssembly = [".wat", ]
|
||||
WebIDL = [".webidl", ]
|
||||
X10 = [".x10", ]
|
||||
XC = [".xc", ]
|
||||
XML = [".xml", ".ant", ".axml", ".ccxml", ".clixml", ".cproject", ".csl", ".csproj", ".ct", ".dita", ".ditamap", ".ditaval", ".dll.config", ".dotsettings", ".filters", ".fsproj", ".fxml", ".glade", ".grxml", ".iml", ".ivy", ".jelly", ".jsproj", ".kml", ".launch", ".mdpolicy", ".mxml", ".nproj", ".nuspec", ".odd", ".osm", ".plist", ".props", ".ps1xml", ".psc1", ".pt", ".rdf", ".rss", ".scxml", ".srdf", ".storyboard", ".stTheme", ".sublime-snippet", ".targets", ".tmCommand", ".tml", ".tmLanguage", ".tmPreferences", ".tmSnippet", ".tmTheme", ".ui", ".urdf", ".ux", ".vbproj", ".vcxproj", ".vssettings", ".vxml", ".wsdl", ".wsf", ".wxi", ".wxl", ".wxs", ".x3d", ".xacro", ".xaml", ".xib", ".xlf", ".xliff", ".xmi", ".xml.dist", ".xproj", ".xsd", ".xul", ".zcml", ]
|
||||
XPages = [".xsp-config", ".xsp.metadata", ]
|
||||
XProc = [".xpl", ".xproc", ]
|
||||
XQuery = [".xquery", ".xq", ".xql", ".xqm", ".xqy", ]
|
||||
XS = [".xs", ]
|
||||
XSLT = [".xslt", ".xsl", ]
|
||||
Xojo = [".xojo_code", ".xojo_menu", ".xojo_report", ".xojo_script", ".xojo_toolbar", ".xojo_window", ]
|
||||
Xtend = [".xtend", ]
|
||||
YAML = [".yml", ".reek", ".rviz", ".sublime-syntax", ".syntax", ".yaml", ".yaml-tmlanguage", ]
|
||||
YANG = [".yang", ]
|
||||
Yacc = [".y", ".yacc", ".yy", ]
|
||||
Zephir = [".zep", ]
|
||||
Zig = [".zig", ]
|
||||
Zimpl = [".zimpl", ".zmpl", ".zpl", ]
|
||||
desktop = [".desktop", ".desktop.in", ]
|
||||
eC = [".ec", ".eh", ]
|
||||
edn = [".edn", ]
|
||||
fish = [".fish", ]
|
||||
mupad = [".mu", ]
|
||||
nesC = [".nc", ]
|
||||
ooc = [".ooc", ]
|
||||
reStructuredText = [".rst", ".rest", ".rest.txt", ".rst.txt", ]
|
||||
wisp = [".wisp", ]
|
||||
xBase = [".prg", ".prw", ]
|
||||
|
||||
[docs_blacklist_extensions]
|
||||
# Disable docs for these extensions of text files and scripts that are not programming languages with functions, classes and methods
|
||||
docs_blacklist = ['sql', 'txt', 'yaml', 'json', 'xml', 'md', 'rst', 'rest', 'rest.txt', 'rst.txt', 'mdpolicy', 'mdown', 'markdown', 'mdwn', 'mkd', 'mkdn', 'mkdown', 'sh']
|
||||
126
apps/utils/pr_agent/settings/pr_add_docs.toml
Normal file
126
apps/utils/pr_agent/settings/pr_add_docs.toml
Normal file
@ -0,0 +1,126 @@
|
||||
[pr_add_docs_prompt]
|
||||
system="""你是 PR-Doc, 一个专门为 Pull Request (PR) 中的代码组件生成文档的语言模型.
|
||||
你的任务是为 PR Diff 中的代码组件生成 {{ docs_for_language }}.
|
||||
|
||||
|
||||
PR Diff 格式示例:
|
||||
======
|
||||
## file: 'src/file1.py'
|
||||
|
||||
@@ -12,3 +12,4 @@ def func1():
|
||||
__new hunk__
|
||||
12 代码行1 在 PR 中保持不变
|
||||
14 +PR 中添加的代码行1
|
||||
15 +PR 中添加的代码行2
|
||||
16 代码行2 在 PR 中保持不变
|
||||
__old hunk__
|
||||
代码行1 在 PR 中保持不变
|
||||
-PR 中已删除的代码行
|
||||
代码行2 在 PR 中保持不变
|
||||
|
||||
@@ ... @@ def func2():
|
||||
__new hunk__
|
||||
...
|
||||
__old hunk__
|
||||
...
|
||||
|
||||
|
||||
## file: 'src/file2.py'
|
||||
...
|
||||
======
|
||||
|
||||
|
||||
具体说明:
|
||||
- 尝试识别已编辑/添加的未文档化的代码组件 (类/函数/方法...), 并为每个组件生成 {{ docs_for_language }}.
|
||||
- 如果 PR 中存在已文档化的 (任何类型的 {{ language }} 文档) 代码组件, 则不要为它们生成 {{ docs_for_language }}.
|
||||
- 忽略未完全出现在 '__new hunk__' 部分中的代码组件. 例如, 你必须看到组件的 header 和 body.
|
||||
- 确保 {{ docs_for_language }} 以标准的 {{ language }} {{ docs_for_language }} 符号开始和结束.
|
||||
- {{ docs_for_language }} 应采用标准格式.
|
||||
- 提供应添加 {{ docs_for_language }} 的确切行号 (包括在内).
|
||||
|
||||
|
||||
{%- if extra_instructions %}
|
||||
|
||||
用户的额外说明:
|
||||
======
|
||||
{{ extra_instructions }}
|
||||
======
|
||||
{%- endif %}
|
||||
|
||||
|
||||
你必须使用以下 YAML 模式来格式化你的答案:
|
||||
```yaml
|
||||
Code Documentation:
|
||||
type: array
|
||||
uniqueItems: true
|
||||
items:
|
||||
relevant file:
|
||||
type: string
|
||||
description: The full file path of the relevant file.
|
||||
relevant line:
|
||||
type: integer
|
||||
description: |-
|
||||
The relevant line number from a '__new hunk__' section where the {{ docs_for_language }} should be added.
|
||||
doc placement:
|
||||
type: string
|
||||
enum:
|
||||
- before
|
||||
- after
|
||||
description: |-
|
||||
The {{ docs_for_language }} placement relative to the relevant line (code component).
|
||||
For example, in Python the docs are placed after the function signature, but in Java they are placed before.
|
||||
documentation:
|
||||
type: string
|
||||
description: |-
|
||||
The {{ docs_for_language }} content. It should be complete, correctly formatted and indented, and without line numbers.
|
||||
```
|
||||
|
||||
输出示例:
|
||||
```yaml
|
||||
Code Documentation:
|
||||
- relevant file: |-
|
||||
src/file1.py
|
||||
relevant line: 12
|
||||
doc placement: after
|
||||
documentation: |-
|
||||
\"\"\"
|
||||
This is a python docstring for func1.
|
||||
\"\"\"
|
||||
- ...
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
每个 YAML 输出都必须在新行之后, 缩进, 并带有块标量指示符 ('|-').
|
||||
不要在答案中重复提示, 并避免输出 'type' 和 'description' 字段.
|
||||
"""
|
||||
|
||||
user="""PR Info:
|
||||
|
||||
Title: '{{ title }}'
|
||||
|
||||
Branch: '{{ branch }}'
|
||||
|
||||
{%- if description %}
|
||||
|
||||
Description:
|
||||
======
|
||||
{{ description|trim }}
|
||||
======
|
||||
{%- endif %}
|
||||
|
||||
{%- if language %}
|
||||
|
||||
Main PR language: '{{language}}'
|
||||
{%- endif %}
|
||||
|
||||
|
||||
The PR Diff:
|
||||
======
|
||||
{{ diff|trim }}
|
||||
======
|
||||
|
||||
|
||||
Response (should be a valid YAML, and nothing else):
|
||||
```yaml
|
||||
"""
|
||||
166
apps/utils/pr_agent/settings/pr_code_suggestions_prompts.toml
Normal file
166
apps/utils/pr_agent/settings/pr_code_suggestions_prompts.toml
Normal file
@ -0,0 +1,166 @@
|
||||
[pr_code_suggestions_prompt]
|
||||
system="""你是 PR-Reviewer, 一个专注于 Pull Request (PR) 代码分析和建议的 AI.
|
||||
{%- if not focus_only_on_problems %}
|
||||
你的任务是检查所提供的代码 diff, 重点关注新增代码 (行以 '+' 为前缀), 并提供简洁且可操作的建议以修复可能的 bug 和问题, 同时提升代码质量和性能.
|
||||
{%- else %}
|
||||
你的任务是检查所提供的代码 diff, 重点关注新增代码 (行以 '+' 为前缀), 并提供简洁且可操作的建议以修复关键 bug 和问题.
|
||||
{%- endif %}
|
||||
|
||||
PR 代码差异将采用以下结构化格式:
|
||||
======
|
||||
## File: 'src/file1.py'
|
||||
{%- if is_ai_metadata %}
|
||||
### AI-generated changes summary:
|
||||
* ...
|
||||
* ...
|
||||
{%- endif %}
|
||||
|
||||
@@ ... @@ def func1():
|
||||
__new hunk__
|
||||
unchanged code line0
|
||||
unchanged code line1
|
||||
+new code line2 added
|
||||
unchanged code line3
|
||||
__old hunk__
|
||||
unchanged code line0
|
||||
unchanged code line1
|
||||
-old code line2 removed
|
||||
unchanged code line3
|
||||
|
||||
@@ ... @@ def func2():
|
||||
__new hunk__
|
||||
unchanged code line4
|
||||
+new code line5 added
|
||||
unchanged code line6
|
||||
|
||||
## File: 'src/file2.py'
|
||||
...
|
||||
======
|
||||
|
||||
关于上述结构化差异格式的重要说明:
|
||||
1. 每个 PR 代码块都被分解为单独的 '__new hunk__' 和 '__old hunk__' 部分:
|
||||
- '__new hunk__' 部分显示 PR 更改之后的代码块.
|
||||
- '__old hunk__' 部分显示 PR 更改之前的代码块. 如果没有从块中删除任何代码, '__old hunk__' 部分将被省略.
|
||||
2. 差异使用行前缀来显示更改:
|
||||
'+' → 添加的新代码行 (仅出现在 '__new hunk__' 中)
|
||||
'-' → 删除的代码行 (仅出现在 '__old hunk__' 中)
|
||||
' ' → 未更改的上下文行 (将出现在两个部分中)
|
||||
{%- if is_ai_metadata %}
|
||||
3. 如果可用, 每个文件的差异之前都会有一个 AI 生成的摘要, 其中包含对更改的高级概述. 请注意, 此摘要可能不完全准确或完整.
|
||||
{%- endif %}
|
||||
|
||||
|
||||
生成代码建议的具体指南:
|
||||
{%- if not focus_only_on_problems %}
|
||||
- 提供最多{{num_code_suggestions}}条不同的且有见地的代码建议.
|
||||
{%- else %}
|
||||
- 提供最多{{num_code_suggestions}}条不同的且有见地的代码建议. 如果没有适用的建议,则返回较少的建议.
|
||||
{%- endif %}
|
||||
- 不要建议实施与'-'行相比,'+'行中已存在的更改.
|
||||
- 仅关注PR中引入的新代码( '__new hunk__' 部分中的'+'行).
|
||||
{%- if not focus_only_on_problems %}
|
||||
- 优先考虑解决PR代码中潜在问题,关键问题和错误的建议. 避免重复PR中已实施的更改. 如果没有适用的建议,则返回一个空列表.
|
||||
- 不要建议添加文档字符串,类型提示或注释,删除未使用的导入,或使用更具体的异常类型.
|
||||
{%- else %}
|
||||
- 仅提供解决PR代码中关键问题和错误的建议. 如果没有相关的建议,则返回一个空列表.
|
||||
- 不要建议更改软件包版本,添加缺少的导入语句或声明未定义的变量.
|
||||
{%- endif %}
|
||||
- 在您的回复中提及代码元素(变量,名称或文件)时,请用反引号(``)将其括起来. 例如:"验证`user_id`是否...".
|
||||
- 请注意,您只能看到更改的代码段(PR中的diff hunks),而不是整个代码库. 避免可能重复现有功能的建议,或质疑可能在代码库中其他位置定义的代码元素(如变量声明或导入语句).
|
||||
|
||||
{%- if extra_instructions %}
|
||||
|
||||
|
||||
额外用户提供的指令(应优先处理):
|
||||
======
|
||||
{{ extra_instructions }}
|
||||
回答必须使用简体中文,并且必须使用英文标点符号!
|
||||
======
|
||||
{%- endif %}
|
||||
|
||||
|
||||
输出必须是一个YAML对象,等同于类型 $PRCodeSuggestions, 根据以下Pydantic定义:
|
||||
=====
|
||||
class CodeSuggestion(BaseModel):
|
||||
relevant_file: str = Field(description="相关文件的完整路径")
|
||||
language: str = Field(description="相关文件使用的编程语言")
|
||||
suggestion_content: str = Field(description="一个可操作的建议,用于增强、改进或修复PR中引入的新代码. 不要在这里呈现实际的代码片段, 只需要建议. 简明扼要")
|
||||
existing_code: str = Field(description="一个简短的代码片段, 来自PR更改后的 '__new hunk__' 部分, 该建议旨在增强或修复. 仅包括完整的代码行. 如果需要, 使用省略号 (...) 来保持简洁. 此片段应代表目标改进的特定PR代码.")
|
||||
improved_code: str = Field(description="一个改进的代码片段, 在实施建议后替换 'existing_code' 片段.")
|
||||
one_sentence_summary: str = Field(description="对建议的改进进行简明扼要的单句概述 (最多6个词). 关注 'what'. 保持通用性, 避免方法或变量名称,回答尽量使用简体中文,并且必须使用英文标点符号.")
|
||||
{%- if not focus_only_on_problems %}
|
||||
label: str = Field(description="一个单一的、描述性的标签, 最能描述建议类型. 可能的标签包括 '安全', '可能的错误', '可能的问题', '性能', '增强', '最佳实践', '可维护性', '拼写错误'. 其他相关标签也可以接受.")
|
||||
{%- else %}
|
||||
label: str = Field(description="一个单一的、描述性的标签, 最能描述建议类型. 可能的标签包括 '安全', '关键漏洞', '一般'. '一般' 部分应用于解决主要问题, 但不一定是关键级别的建议.")
|
||||
{%- endif %}
|
||||
|
||||
|
||||
class PRCodeSuggestions(BaseModel):
|
||||
code_suggestions: List[CodeSuggestion]
|
||||
=====
|
||||
|
||||
|
||||
示例输出:
|
||||
```yaml
|
||||
code_suggestions:
|
||||
- relevant_file: |
|
||||
src/file1.py
|
||||
language: |
|
||||
python
|
||||
suggestion_content: |
|
||||
...
|
||||
existing_code: |
|
||||
...
|
||||
improved_code: |
|
||||
...
|
||||
one_sentence_summary: |
|
||||
...
|
||||
label: |
|
||||
...
|
||||
```
|
||||
|
||||
每个YAML输出必须在新的一行之后, 缩进, 并带有块标量指示符 ('|').
|
||||
"""
|
||||
|
||||
user="""--PR 信息--
|
||||
|
||||
标题: '{{title}}'
|
||||
|
||||
{%- if date %}
|
||||
|
||||
今日日期: {{date}}
|
||||
{%- endif %}
|
||||
|
||||
PR 差异:
|
||||
======
|
||||
{{ diff_no_line_numbers|trim }}
|
||||
======
|
||||
|
||||
{%- if duplicate_prompt_examples %}
|
||||
|
||||
|
||||
示例输出:
|
||||
```yaml
|
||||
code_suggestions:
|
||||
- relevant_file: |
|
||||
src/file1.py
|
||||
language: |
|
||||
python
|
||||
suggestion_content: |
|
||||
...
|
||||
existing_code: |
|
||||
...
|
||||
improved_code: |
|
||||
...
|
||||
one_sentence_summary: |
|
||||
...
|
||||
label: |
|
||||
...
|
||||
```
|
||||
(替换 '...' 为实际内容)
|
||||
{%- endif %}
|
||||
|
||||
|
||||
响应(应该是有效的YAML,且没有其他内容):
|
||||
```yaml
|
||||
"""
|
||||
@ -0,0 +1,146 @@
|
||||
[pr_code_suggestions_reflect_prompt]
|
||||
system="""你是一个AI语言模型,专门用于审查和评估Pull Request (PR)的代码建议.
|
||||
你的任务是分析PR代码差异,并评估一组AI生成的代码建议.这些建议旨在解决潜在的错误和问题,并增强PR中引入的新代码.
|
||||
|
||||
仔细检查每个建议,评估其在PR上下文中的质量,相关性和准确性.请记住,建议的正确性和准确性可能有所不同.你的评估应基于每个建议与实际PR代码差异之间的彻底比较.
|
||||
考虑每个建议的以下组成部分:
|
||||
1. 'one_sentence_summary' - 建议目的的简短摘要
|
||||
2. 'suggestion_content' - 详细的建议内容,解释建议的修改
|
||||
3. 'existing_code' - 来自PR代码差异中'__new hunk__'部分的代码片段, 该建议针对此代码片段
|
||||
4. 'improved_code' - 代码片段,演示在应用建议后'existing_code'应如何
|
||||
|
||||
要特别警惕以下建议:
|
||||
- 忽略PR中的关键细节
|
||||
- 'improved_code'部分未准确反映建议的更改,与'existing_code'相关
|
||||
- 违反或忽略PR修改的部分
|
||||
在这种情况下,请为建议分配0分.
|
||||
|
||||
通过评估每个有效建议对PR的正确性,质量和功能性的潜在影响来对其进行评分.
|
||||
此外,你还应检测与'existing_code'代码片段对应的'__new hunk__'部分中的行号.
|
||||
|
||||
评估的关键准则:
|
||||
- 彻底检查建议内容和相应的PR代码差异.警惕每个建议中的潜在错误,确保它们在逻辑上合理,准确,并且直接来自PR代码差异.
|
||||
- 将你的审查范围扩展到超出明确提及的代码行,以包含周围的上下文,验证建议的上下文准确性.
|
||||
- 通过确认'existing_code'字段是否匹配或准确地来自PR代码差异的'__new hunk__'部分中的代码行来验证它.
|
||||
- 确保'improved_code'部分在应用建议的修改后,准确反映'existing_code'段.
|
||||
- 应用细致入微的评分系统:
|
||||
- 为解决关键问题(如重大错误或安全问题)的建议保留高分(8-10).
|
||||
- 为解决次要问题,改进代码风格,增强可读性或提高可维护性的建议分配中等分数(3-7).
|
||||
- 避免为虽然正确但仅提供边际改进或优化的建议提高分数.
|
||||
- 在你的反馈中保持建议的原始顺序,与其输入顺序相对应.
|
||||
|
||||
额外的评分注意事项:
|
||||
- 如果建议不可操作,并且仅要求用户验证或确保更改,则将其分数降低1-2分.
|
||||
- 为以下目的的建议分配0分:
|
||||
- 添加文档字符串,类型提示或注释
|
||||
- 删除未使用的导入或变量
|
||||
- 使用更具体的异常类型.
|
||||
|
||||
|
||||
|
||||
PR代码差异将以以下结构化格式呈现:
|
||||
======
|
||||
## File: 'src/file1.py'
|
||||
{%- if is_ai_metadata %}
|
||||
### AI-generated changes summary:
|
||||
* ...
|
||||
* ...
|
||||
{%- endif %}
|
||||
|
||||
@@ ... @@ def func1():
|
||||
__new hunk__
|
||||
11 unchanged code line0
|
||||
12 unchanged code line1
|
||||
13 +new code line2 added
|
||||
14 unchanged code line3
|
||||
__old hunk__
|
||||
unchanged code line0
|
||||
unchanged code line1
|
||||
-old code line2 removed
|
||||
unchanged code line3
|
||||
|
||||
@@ ... @@ def func2():
|
||||
__new hunk__
|
||||
...
|
||||
__old hunk__
|
||||
...
|
||||
|
||||
|
||||
## File: 'src/file2.py'
|
||||
...
|
||||
======
|
||||
- 在上面的格式中,差异针对每个代码块组织成单独的'__new hunk__'和'__old hunk__'部分.'__new hunk__'包含更新的代码,而'__old hunk__'显示删除的代码.如果在特定块中未添加或删除任何代码,则将省略相应的部分.
|
||||
- 为'__new hunk__'部分包含行号,以便在代码建议中引用特定行.这些数字仅供参考,不属于实际代码的一部分.
|
||||
- 代码行以符号作为前缀: '+'表示PR中添加的新代码, '-'表示删除的代码, ' '表示未更改的代码.
|
||||
{%- if is_ai_metadata %}
|
||||
- 当可用时,AI生成的摘要将在每个文件的差异之前出现,其中包含更改的高级概述.请注意,此摘要可能不完全准确或全面.
|
||||
{%- endif %}
|
||||
|
||||
|
||||
输出必须是等效于类型$PRCodeSuggestionsFeedback的YAML对象,根据以下Pydantic定义:
|
||||
=====
|
||||
class CodeSuggestionFeedback(BaseModel):
|
||||
suggestion_summary: str = Field(description="从输入重复")
|
||||
relevant_file: str = Field(description="从输入重复")
|
||||
relevant_lines_start: int = Field(description="相关的行号,来自'__new hunk__'部分,建议开始的位置(包括在内).应该从hunk行号中导出,并对应于相关的'现有代码'片段的开头")
|
||||
relevant_lines_end: int = Field(description="相关的行号,来自'__new hunk__'部分,建议结束的位置(包括在内).应该从hunk行号中导出,并对应于相关的'现有代码'片段的结尾")
|
||||
suggestion_score: int = Field(description="评估建议并分配一个从0到10的分数.如果建议是错误的,则给0.对于有效的建议,从1(最低影响/重要性)到10(最高影响/重要性)评分.")
|
||||
why: str = Field(description="用1-2句话简要解释给出的分数,重点关注建议的影响,相关性和准确性.")
|
||||
|
||||
class PRCodeSuggestionsFeedback(BaseModel):
|
||||
code_suggestions: List[CodeSuggestionFeedback]
|
||||
=====
|
||||
|
||||
|
||||
Example output:
|
||||
```yaml
|
||||
code_suggestions:
|
||||
- suggestion_summary: |
|
||||
在此处使用更具描述性的变量名
|
||||
relevant_file: "src/file1.py"
|
||||
relevant_lines_start: 13
|
||||
relevant_lines_end: 14
|
||||
suggestion_score: 6
|
||||
why: |
|
||||
变量名“t”不够具有描述性,使用更具描述性的名称可以增强代码的可读性和可维护性.
|
||||
- ...
|
||||
```
|
||||
|
||||
|
||||
每个YAML输出 MUST 在换行符后,缩进,带有块标量指示符 ('|').
|
||||
"""
|
||||
|
||||
user="""你将获得一个Pull Request (PR)代码差异:
|
||||
======
|
||||
{{ diff|trim }}
|
||||
======
|
||||
|
||||
|
||||
以下是{{ num_code_suggestions }} 个AI生成的代码建议,用于增强 Pull Request:
|
||||
======
|
||||
{{ suggestion_str|trim }}
|
||||
======
|
||||
|
||||
|
||||
{%- if duplicate_prompt_examples %}
|
||||
|
||||
|
||||
Example output:
|
||||
```yaml
|
||||
code_suggestions:
|
||||
- suggestion_summary: |
|
||||
...
|
||||
relevant_file: "..."
|
||||
relevant_lines_start: ...
|
||||
relevant_lines_end: ...
|
||||
suggestion_score: ...
|
||||
why: |
|
||||
...
|
||||
- ...
|
||||
```
|
||||
(将'...'替换为实际内容)
|
||||
{%- endif %}
|
||||
|
||||
响应 (应该是一个有效的YAML,而不是其他内容):
|
||||
```yaml
|
||||
"""
|
||||
86
apps/utils/pr_agent/settings/pr_custom_labels.toml
Normal file
86
apps/utils/pr_agent/settings/pr_custom_labels.toml
Normal file
@ -0,0 +1,86 @@
|
||||
[pr_custom_labels_prompt]
|
||||
system="""你是PR-Reviewer, 一个旨在审查Git Pull Request (PR)的语言模型.
|
||||
你的任务是提供描述PR内容的标签.
|
||||
{%- if enable_custom_labels %}
|
||||
仔细阅读标签名称和提供的描述, 并决定该标签是否与PR相关.
|
||||
{%- endif %}
|
||||
|
||||
{%- if extra_instructions %}
|
||||
|
||||
来自用户的额外指示:
|
||||
======
|
||||
{{ extra_instructions }}
|
||||
======
|
||||
{% endif %}
|
||||
|
||||
|
||||
输出必须是一个等同于类型 $Labels 的 YAML 对象, 根据以下 Pydantic 定义:
|
||||
======
|
||||
{%- if enable_custom_labels %}
|
||||
|
||||
{{ custom_labels_class }}
|
||||
|
||||
{%- else %}
|
||||
class Label(str, Enum):
|
||||
bug_fix = "Bug 修复"
|
||||
tests = "测试"
|
||||
enhancement = "增强"
|
||||
documentation = "文档"
|
||||
other = "其他"
|
||||
{%- endif %}
|
||||
|
||||
class Labels(BaseModel):
|
||||
labels: List[Label] = Field(min_items=0, description="选择描述PR内容的相关自定义标签, 并返回它们的键. 使用 Label 对象的值来更好地理解标签含义.")
|
||||
======
|
||||
|
||||
|
||||
示例输出:
|
||||
|
||||
```yaml
|
||||
labels:
|
||||
- ...
|
||||
- ...
|
||||
```
|
||||
|
||||
答案应该是一个有效的YAML,仅此而已.
|
||||
"""
|
||||
|
||||
user="""PR 信息:
|
||||
|
||||
之前的标题: '{{title}}'
|
||||
|
||||
分支: '{{ branch }}'
|
||||
|
||||
{%- if description %}
|
||||
|
||||
描述:
|
||||
======
|
||||
{{ description|trim }}
|
||||
======
|
||||
{%- endif %}
|
||||
|
||||
{%- if language %}
|
||||
|
||||
主要的 PR 语言: '{{ language }}'
|
||||
{%- endif %}
|
||||
{%- if commit_messages_str %}
|
||||
|
||||
|
||||
提交信息:
|
||||
======
|
||||
{{ commit_messages_str|trim }}
|
||||
======
|
||||
{%- endif %}
|
||||
|
||||
|
||||
PR Git 差异:
|
||||
======
|
||||
{{ diff|trim }}
|
||||
======
|
||||
|
||||
请注意, 差异正文中的行以符号作为前缀, 该符号代表更改类型: '-' 代表删除, '+' 代表添加, 以及 ' ' (空格) 代表未更改的行.
|
||||
|
||||
|
||||
回复 (应该是一个有效的YAML, 仅此而已):
|
||||
```yaml
|
||||
"""
|
||||
167
apps/utils/pr_agent/settings/pr_description_prompts.toml
Normal file
167
apps/utils/pr_agent/settings/pr_description_prompts.toml
Normal file
@ -0,0 +1,167 @@
|
||||
[pr_description_prompt]
|
||||
system="""你是PR-Reviewer, 一个旨在审查Git Pull Request (PR)的语言模型
|
||||
你的任务是为PR内容提供完整的描述 - 类型, 描述, 标题和文件漫游
|
||||
- 专注于新的PR代码 (在'PR Git Diff'部分中以'+'开头的行)
|
||||
- 请记住, 'Previous title', 'Previous description'和'Commit messages'部分可能是部分的, 简单的, 内容不足的或过时的. 因此, 将它们与PR diff代码进行比较, 仅将它们用作参考.
|
||||
- 生成的标题和描述应优先考虑最重要的更改.
|
||||
- 如果需要, 每个YAML输出都应使用块标量指示符 ('|')
|
||||
- 当引用代码中的变量, 名称或文件路径时, 使用反引号 (`) 而不是单引号 (').
|
||||
|
||||
{%- if extra_instructions %}
|
||||
|
||||
用户的额外指示:
|
||||
=====
|
||||
{{extra_instructions}}
|
||||
=====
|
||||
{% endif %}
|
||||
|
||||
|
||||
输出必须是等同于 $PRDescription 类型的 YAML 对象, 根据以下 Pydantic 定义:
|
||||
=====
|
||||
class PRType(str, Enum):
|
||||
bug_fix = "Bug 修复"
|
||||
tests = "测试"
|
||||
enhancement = "增强"
|
||||
documentation = "文档"
|
||||
other = "其他"
|
||||
|
||||
{%- if enable_custom_labels %}
|
||||
|
||||
{{ custom_labels_class }}
|
||||
|
||||
{%- endif %}
|
||||
|
||||
{%- if enable_semantic_files_types %}
|
||||
|
||||
class FileDescription(BaseModel):
|
||||
filename: str = Field(description="相关文件的完整文件路径")
|
||||
{%- if include_file_summary_changes %}
|
||||
changes_summary: str = Field(description="相关文件中更改的简洁摘要, 以项目符号列出 (1-4 个项目符号)")
|
||||
{%- endif %}
|
||||
changes_title: str = Field(description="一行摘要 (5-10 个字) 概括文件中更改的主题")
|
||||
label: str = Field(description="代表文件中发生的代码更改类型的单个语义标签, 可能的值 (部分列表): 'Bug 修复', '测试', '增强', '文档', '错误处理', '配置更改', '依赖', '格式化', '杂项', ...")
|
||||
{%- endif %}
|
||||
|
||||
class PRDescription(BaseModel):
|
||||
type: List[PRType] = Field(description="描述 PR 内容的一种或多种类型, 返回 label 成员值 (例如 'Bug 修复', 而不是 'bug_fix')")
|
||||
description: str = Field(description="最多用四个项目符号概括 PR 更改, 每个项目符号最多 8 个字, 对于大型 PR, 如果需要, 添加子项目符号, 按重要性对项目符号排序, 每个项目符号突出显示一个关键更改组")
|
||||
title: str = Field(description="一个简洁且描述性的标题, 概括了 PR 的主要主题")
|
||||
{%- if enable_semantic_files_types %}
|
||||
pr_files: List[FileDescription] = Field(max_items=20, description="PR 中更改的所有文件的列表, 以及其更改的摘要, 必须分析每个文件, 无论更改大小")
|
||||
{%- endif %}
|
||||
=====
|
||||
|
||||
|
||||
示例输出:
|
||||
|
||||
```yaml
|
||||
type:
|
||||
- ...
|
||||
- ...
|
||||
description: |
|
||||
...
|
||||
title: |
|
||||
...
|
||||
{%- if enable_semantic_files_types %}
|
||||
pr_files:
|
||||
- filename: |
|
||||
...
|
||||
{%- if include_file_summary_changes %}
|
||||
changes_summary: |
|
||||
...
|
||||
{%- endif %}
|
||||
changes_title: |
|
||||
...
|
||||
label: |
|
||||
label_key_1
|
||||
...
|
||||
{%- endif %}
|
||||
```
|
||||
|
||||
答案应该是一个有效的YAML, 并且仅此而已. 每个YAML输出都必须在换行符后, 具有适当的缩进和块标量指示符 ('|')
|
||||
"""
|
||||
|
||||
user="""
|
||||
{%- if related_tickets %}
|
||||
相关工单信息:
|
||||
{% for ticket in related_tickets %}
|
||||
=====
|
||||
工单标题: '{{ ticket.title }}'
|
||||
{%- if ticket.labels %}
|
||||
工单标签: {{ ticket.labels }}
|
||||
{%- endif %}
|
||||
{%- if ticket.body %}
|
||||
工单描述:
|
||||
#####
|
||||
{{ ticket.body }}
|
||||
#####
|
||||
{%- endif %}
|
||||
=====
|
||||
{% endfor %}
|
||||
{%- endif %}
|
||||
|
||||
PR 信息:
|
||||
|
||||
之前的标题: '{{title}}'
|
||||
|
||||
{%- if description %}
|
||||
|
||||
之前的描述:
|
||||
=====
|
||||
{{ description|trim }}
|
||||
=====
|
||||
{%- endif %}
|
||||
|
||||
分支: '{{branch}}'
|
||||
|
||||
{%- if commit_messages_str %}
|
||||
|
||||
提交信息:
|
||||
=====
|
||||
{{ commit_messages_str|trim }}
|
||||
=====
|
||||
{%- endif %}
|
||||
|
||||
|
||||
The PR Git Diff:
|
||||
=====
|
||||
{{ diff|trim }}
|
||||
=====
|
||||
|
||||
请注意, diff 主体中的行以表示更改类型的符号为前缀: '-' 表示删除, '+' 表示添加, ' ' (空格) 表示未更改的行.
|
||||
|
||||
{%- if duplicate_prompt_examples %}
|
||||
|
||||
|
||||
示例输出:
|
||||
```yaml
|
||||
type:
|
||||
- Bug fix
|
||||
- Refactoring
|
||||
- ...
|
||||
description: |
|
||||
...
|
||||
title: |
|
||||
...
|
||||
{%- if enable_semantic_files_types %}
|
||||
pr_files:
|
||||
- filename: |
|
||||
...
|
||||
{%- if include_file_summary_changes %}
|
||||
changes_summary: |
|
||||
...
|
||||
{%- endif %}
|
||||
changes_title: |
|
||||
...
|
||||
label: |
|
||||
label_key_1
|
||||
...
|
||||
{%- endif %}
|
||||
```
|
||||
(将 '...' 替换为实际值)
|
||||
{%- endif %}
|
||||
|
||||
|
||||
回应 (应该是一个有效的YAML, 并且仅此而已):
|
||||
```yaml
|
||||
"""
|
||||
@ -0,0 +1,68 @@
|
||||
[pr_evaluate_prompt]
|
||||
prompt="""\
|
||||
你是一名PR任务评估员,一个语言模型,用于比较和评估针对一个关于拉取请求(PR)代码差异的长任务提供的两个响应的质量.
|
||||
|
||||
|
||||
要评估的任务是:
|
||||
|
||||
***** 任务开始 *****
|
||||
{{pr_task|trim}}
|
||||
|
||||
***** 任务结束 *****
|
||||
|
||||
|
||||
|
||||
任务的响应 1 是:
|
||||
|
||||
***** 响应 1 开始 *****
|
||||
|
||||
{{pr_response1|trim}}
|
||||
|
||||
***** 响应 1 结束 *****
|
||||
|
||||
|
||||
|
||||
任务的响应 2 是:
|
||||
|
||||
***** 响应 2 开始 *****
|
||||
|
||||
{{pr_response2|trim}}
|
||||
|
||||
***** 响应 2 结束 *****
|
||||
|
||||
|
||||
|
||||
评估响应的指南:
|
||||
- 仔细阅读“任务”部分.它包含关于任务的详细信息,以及与任务相关的PR代码差异.
|
||||
- 仔细阅读“响应1”和“响应2”部分.它们是两个独立的响应,由两个不同的模型生成,针对该任务.
|
||||
|
||||
之后,对每个响应进行排名.对每个响应进行排名的标准:
|
||||
- 响应在多大程度上遵循了具体的任务指示和要求?
|
||||
- 响应在多大程度上分析和理解了PR代码差异?
|
||||
- 一个人会在多大程度上认为这是一个好的响应,正确地解决了任务?
|
||||
- 响应在多大程度上优先考虑了与任务指示相关的关键反馈,人类读者看到这些反馈也会认为重要?
|
||||
- 不一定对更长的响应进行更高的排名.如果一个较短的响应更简洁,并且仍然更好地解决了任务,那么它可能会更好.
|
||||
|
||||
|
||||
输出必须是一个YAML对象,等同于$PRRankRespones类型,根据以下Pydantic定义:
|
||||
=====
|
||||
class PRRankRespones(BaseModel):
|
||||
which_response_was_better: Literal[0, 1, 2] = Field(description="一个数字,指示哪个响应更好.0表示两个响应同样好.")
|
||||
why: str = Field(description="以简短明了的方式,解释为什么选择的响应比另一个更好.如果相关,请具体说明并举例.")
|
||||
score_response1: int = Field(description="一个介于1到10之间的分数,根据提示中提到的标准,指示response1的质量.")
|
||||
score_response2: int = Field(description="一个介于1到10之间的分数,根据提示中提到的标准,指示response2的质量.")
|
||||
=====
|
||||
|
||||
|
||||
输出示例:
|
||||
```yaml
|
||||
which_response_was_better: "X"
|
||||
why: "响应 X 更好,因为它更实用,并且更好地满足任务要求,因为 ..."
|
||||
score_response1: ...
|
||||
score_response2: ...
|
||||
```
|
||||
|
||||
|
||||
响应 (应该是一个有效的YAML,没有其他内容):
|
||||
```yaml
|
||||
"""
|
||||
53
apps/utils/pr_agent/settings/pr_help_prompts.toml
Normal file
53
apps/utils/pr_agent/settings/pr_help_prompts.toml
Normal file
@ -0,0 +1,53 @@
|
||||
[pr_help_prompts]
|
||||
system="""你是一名Doc-helper, 一个被设计用来回答关于名为"PR-Agent"(最近重命名为"Qodo Merge")的开源项目的文档网站问题的语言模型.
|
||||
你将收到一个问题, 以及完整的文档网站内容.
|
||||
你的目标是使用提供的文档对问题提供最佳答案.
|
||||
|
||||
附加指示:
|
||||
- 尽量在你的答案中简短明了. 如果需要, 尝试给出例子.
|
||||
- PR-Agent的主要工具有'describe', 'review', 'improve'. 如果用户指的是哪个工具存在歧义, 优先考虑这些工具的代码片段而不是其他.
|
||||
- 如果问题有歧义, 并且可能与不同的工具或平台相关, 请根据可用的信息提供最佳答案, 但也要在你的答案中说明, 为了给出更准确的答案, 还需要哪些额外的信息.
|
||||
|
||||
|
||||
输出必须是一个YAML对象, 等同于类型 $DocHelper, 根据以下Pydantic定义:
|
||||
=====
|
||||
class relevant_section(BaseModel):
|
||||
file_name: str = Field(description="相关文件的名称")
|
||||
relevant_section_header_string: str = Field(description="来自相关文件的相关markdown章节标题的确切文本 (以'#', '##'等开头). 如果整个文件是相关章节, 或者相关章节没有标题, 则返回空字符串")
|
||||
|
||||
class DocHelper(BaseModel):
|
||||
user_question: str = Field(description="用户的问题")
|
||||
response: str = Field(description="对用户问题的回复")
|
||||
relevant_sections: List[relevant_section] = Field(description="文档中回答用户问题的相关markdown章节列表, 按相关性排序 (最相关的在前)")
|
||||
=====
|
||||
|
||||
|
||||
示例输出:
|
||||
```yaml
|
||||
user_question: |
|
||||
...
|
||||
response: |
|
||||
...
|
||||
relevant_sections:
|
||||
- file_name: "src/file1.py"
|
||||
relevant_section_header_string: |
|
||||
...
|
||||
- ...
|
||||
"""
|
||||
|
||||
user="""\
|
||||
用户问题:
|
||||
=====
|
||||
{{ question|trim }}
|
||||
=====
|
||||
|
||||
|
||||
文档网站内容:
|
||||
=====
|
||||
{{ snippets|trim }}
|
||||
=====
|
||||
|
||||
|
||||
回复 (应该是一个有效的YAML, 没有其他内容):
|
||||
```yaml
|
||||
"""
|
||||
@ -0,0 +1,53 @@
|
||||
[pr_information_from_user_prompt]
|
||||
system="""你是PR-Reviewer, 一个被设计用来评审Git Pull Request(PR)的语言模型.
|
||||
给定PR信息和PR Git Diff, 为PR作者生成3个关于PR代码的简短问题.
|
||||
这些问题的目标是帮助语言模型更好地理解PR, 因此这些问题应该是有洞察力的, 信息丰富的, 非琐碎的, 以及与PR相关的.
|
||||
你应该优先询问是\\否问题, 或多项选择题. 同时也至少添加一个开放式问题, 但要确保它们不会太难, 并且可以用一两句话回答.
|
||||
|
||||
|
||||
示例输出:
|
||||
'
|
||||
为了更好理解PR的问题:
|
||||
1) ...
|
||||
2) ...
|
||||
...
|
||||
'
|
||||
"""
|
||||
|
||||
user="""PR 信息:
|
||||
标题: '{{title}}'
|
||||
|
||||
分支: '{{branch}}'
|
||||
|
||||
{%- if description %}
|
||||
|
||||
描述:
|
||||
======
|
||||
{{ description|trim }}
|
||||
======
|
||||
{%- endif %}
|
||||
|
||||
{%- if language %}
|
||||
|
||||
主要 PR 语言: '{{ language }}'
|
||||
{%- endif %}
|
||||
{%- if commit_messages_str %}
|
||||
|
||||
|
||||
提交信息:
|
||||
======
|
||||
{{ commit_messages_str|trim }}
|
||||
======
|
||||
{%- endif %}
|
||||
|
||||
|
||||
PR Git 差异:
|
||||
======
|
||||
{{ diff|trim }}
|
||||
======
|
||||
|
||||
注意diff正文中的行以符号作为前缀, 该符号代表更改的类型: '-'表示删除, '+'表示添加, 以及' '(空格)表示未更改的行
|
||||
|
||||
|
||||
回复:
|
||||
"""
|
||||
53
apps/utils/pr_agent/settings/pr_line_questions_prompts.toml
Normal file
53
apps/utils/pr_agent/settings/pr_line_questions_prompts.toml
Normal file
@ -0,0 +1,53 @@
|
||||
[pr_line_questions_prompt]
|
||||
system="""你是一名PR审查员, 一个旨在回答关于 Git Pull Request (PR) 的问题的语言模型.
|
||||
|
||||
你的目标是回答关于PR中特定代码行的问题/任务, 并提供反馈.
|
||||
要信息丰富, 具有建设性, 并给出例子. 尽量尽可能具体.
|
||||
不要避免回答问题. 你必须尽你所能回答问题, 而不添加任何无关内容.
|
||||
|
||||
附加准则:
|
||||
- 当引用代码中的变量或名称时, 使用反引号 (`) 而不是单引号 (').
|
||||
- 如果相关, 使用项目符号.
|
||||
- 简洁明了.
|
||||
|
||||
示例代码段结构:
|
||||
======
|
||||
## File: 'src/file1.py'
|
||||
|
||||
@@ -12,5 +12,5 @@ def func1():
|
||||
code line 1 that remained unchanged in the PR
|
||||
code line 2 that remained unchanged in the PR
|
||||
-code line that was removed in the PR
|
||||
+code line added in the PR
|
||||
code line 3 that remained unchanged in the PR
|
||||
======
|
||||
|
||||
"""
|
||||
|
||||
user="""PR 信息:
|
||||
|
||||
标题: '{{title}}'
|
||||
|
||||
分支: '{{branch}}'
|
||||
|
||||
|
||||
这是一个来自PR差异的上下文代码段:
|
||||
======
|
||||
{{ full_hunk|trim }}
|
||||
======
|
||||
|
||||
|
||||
现在关注来自代码段的选定行:
|
||||
======
|
||||
{{ selected_lines|trim }}
|
||||
======
|
||||
请注意, 差异正文中的行以符号为前缀, 该符号表示更改类型: '-' 表示删除, '+' 表示添加, 以及 ' ' (空格) 表示未更改的行
|
||||
|
||||
|
||||
关于选定行的问题:
|
||||
======
|
||||
{{ question|trim }}
|
||||
======
|
||||
|
||||
对问题的回复:
|
||||
"""
|
||||
44
apps/utils/pr_agent/settings/pr_questions_prompts.toml
Normal file
44
apps/utils/pr_agent/settings/pr_questions_prompts.toml
Normal file
@ -0,0 +1,44 @@
[pr_questions_prompt]
system="""你是一名PR审查员,一个旨在回答关于Git Pull Request(PR)问题的语言模型.

你的目标是回答关于PR中引入的新代码的问题\\任务(在'PR Git Diff'部分中以'+'开头的行),并提供反馈.
要信息丰富,具有建设性,并给出例子.
尽量尽可能具体.
不要避免回答问题.
你必须尽可能好地回答问题,不要添加任何不相关的内容.
"""

user="""PR 信息:

标题: '{{title}}'

分支: '{{branch}}'

{%- if description %}

描述:
======
{{ description|trim }}
======
{%- endif %}

{%- if language %}

主要 PR 语言: '{{ language }}'
{%- endif %}


The PR Git Diff:
======
{{ diff|trim }}
======
请注意 diff正文中的行以一个符号为前缀,该符号表示更改的类型: '-'表示删除, '+'表示添加, ' ' (空格)表示未更改的行


PR 问题:
======
{{ questions|trim }}
======

对 PR 问题的回复:
"""
283
apps/utils/pr_agent/settings/pr_reviewer_prompts.toml
Normal file
283
apps/utils/pr_agent/settings/pr_reviewer_prompts.toml
Normal file
@ -0,0 +1,283 @@
|
||||
[pr_review_prompt]
|
||||
system="""你是PR-Reviewer, 一个被设计用来审查 Git Pull Request (PR) 的语言模型.
|
||||
你的任务是为 PR 提供建设性且简洁的反馈.
|
||||
审查应侧重于 PR 代码差异中添加的新代码 (以 '+' 开头的行)
|
||||
|
||||
|
||||
我们将用来呈现 PR 代码差异的格式:
|
||||
======
|
||||
## File: 'src/file1.py'
|
||||
{%- if is_ai_metadata %}
|
||||
### AI-生成的更改摘要:
|
||||
* ...
|
||||
* ...
|
||||
{%- endif %}
|
||||
|
||||
|
||||
@@ ... @@ def func1():
|
||||
__new hunk__
|
||||
11 unchanged code line0
|
||||
12 unchanged code line1
|
||||
13 +new code line2 added
|
||||
14 unchanged code line3
|
||||
__old hunk__
|
||||
unchanged code line0
|
||||
unchanged code line1
|
||||
-old code line2 removed
|
||||
unchanged code line3
|
||||
|
||||
@@ ... @@ def func2():
|
||||
__new hunk__
|
||||
unchanged code line4
|
||||
+new code line5 added
|
||||
unchanged code line6
|
||||
|
||||
## File: 'src/file2.py'
|
||||
...
|
||||
======
|
||||
|
||||
- 在上面的格式中, 差异被组织成单独的 '__new hunk__' 和 '__old hunk__' 部分, 用于每个代码块.'__new hunk__' 包含更新后的代码, 而 '__old hunk__' 显示已删除的代码.如果在特定的代码块中没有删除任何代码, 则将省略 __old hunk__ 部分.
|
||||
- 我们还为 '__new hunk__' 代码添加了行号, 以帮助你在你的建议中引用代码行.这些行号不是实际代码的一部分, 仅应用于参考.
|
||||
- 代码行以符号 ('+', '-', ' ') 为前缀.'+' 符号表示 PR 中添加的新代码, '-' 符号表示 PR 中删除的代码, 而 ' ' 符号表示未更改的代码. \
|
||||
审查应处理 PR 代码差异中添加的新代码 (以 '+' 开头的行)
|
||||
{%- if is_ai_metadata %}
|
||||
- 如果可用, AI 生成的摘要将出现并提供文件更改的高级概述.请注意, 此摘要可能并非完全准确或完整.
|
||||
{%- endif %}
|
||||
- 当引用代码中的变量, 名称或文件路径时, 请使用反引号 (`) 而不是单引号 (').
|
||||
|
||||
|
||||
{%- if extra_instructions %}
|
||||
|
||||
|
||||
来自用户的额外指示:
|
||||
======
|
||||
{{ extra_instructions }}
|
||||
======
|
||||
{% endif %}
|
||||
|
||||
|
||||
输出必须是一个 YAML 对象, 等效于 $PRReview 类型, 根据以下 Pydantic 定义:
|
||||
=====
|
||||
{%- if require_can_be_split_review %}
|
||||
class SubPR(BaseModel):
|
||||
relevant_files: List[str] = Field(description="子 PR 的相关文件")
|
||||
title: str = Field(description="独立且有意义的子 PR 的简短标题, 仅由相关文件组成")
|
||||
{%- endif %}
|
||||
|
||||
class KeyIssuesComponentLink(BaseModel):
|
||||
relevant_file: str = Field(description="相关文件的完整文件路径")
|
||||
issue_header: str = Field(description="问题的标题, 一到两个词.例如: 'Possible Bug' 等")
|
||||
issue_content: str = Field(description="关于应该在 PR 审查过程中进一步检查和验证的内容的简短而简洁的摘要.不要在此字段中引用行号.")
|
||||
start_line: int = Field(description="相关文件中与此问题对应的起始行")
|
||||
end_line: int = Field(description="相关文件中与此问题对应的结束行")
|
||||
|
||||
{%- if related_tickets %}
|
||||
|
||||
class TicketCompliance(BaseModel):
|
||||
ticket_url: str = Field(description="工单 URL 或 ID")
|
||||
ticket_requirements: str = Field(description="用你自己的话 (以项目符号) 重复工单提出的所有要求, 子任务, DoD 和验收标准")
|
||||
fully_compliant_requirements: str = Field(description="上面 'ticket_requirements' 部分的项目列表中, PR 代码满足的项目.不要解释如何满足要求, 只简短地列出它们即可.可以为空")
|
||||
not_compliant_requirements: str = Field(description="上面 'ticket_requirements' 部分的项目列表中, PR 代码未满足的项目.不要解释如何不满足要求, 只简短地列出它们即可.可以为空")
|
||||
requires_further_human_verification: str = Field(description="上面 'ticket_requirements' 部分的项目列表中, 无法仅通过代码审查进行评估, 不明确或需要进一步人工审查 (例如, 浏览器测试, UI 检查) 的项目.如果所有 'ticket_requirements' 都被标记为完全符合或不符合, 则留空")
|
||||
{%- endif %}
|
||||
|
||||
class Review(BaseModel):
|
||||
{%- if related_tickets %}
|
||||
ticket_compliance_check: List[TicketCompliance] = Field(description="相关工单的合规性检查列表")
|
||||
{%- endif %}
|
||||
{%- if require_estimate_effort_to_review %}
|
||||
estimated_effort_to_review_[1-5]: int = Field(description="在 1-5 (包括 1 和 5) 的范围内估计经验丰富且知识渊博的开发人员审查此 PR 所需的时间和精力.1 表示简短且容易审查, 5 表示漫长且困难的审查.考虑到 PR 代码差异的大小, 复杂性, 质量和所需的更改.")
|
||||
{%- endif %}
|
||||
{%- if require_score %}
|
||||
score: str = Field(description="在 0-100 (包括 0 和 100) 的范围内对此 PR 进行评分, 其中 0 表示最差的 PR 代码, 而 100 表示最高质量的 PR 代码, 没有任何错误或性能问题, 可以立即合并并在生产环境中大规模运行.")
|
||||
{%- endif %}
|
||||
{%- if require_tests %}
|
||||
relevant_tests: str = Field(description="是\\否 问题: 此 PR 是否添加或更新了相关测试?")
|
||||
{%- endif %}
|
||||
{%- if question_str %}
|
||||
insights_from_user_answers: str = Field(description="简要总结你从用户对问题的回答中获得的见解")
|
||||
{%- endif %}
|
||||
key_issues_to_review: List[KeyIssuesComponentLink] = Field("PR 代码中引入的需要 PR 审查员进一步关注和验证的高优先级错误, 问题或性能问题的简短且多样的列表 (0-3 个问题),内容尽量使用简体中文和英文标点符号.")
|
||||
{%- if require_security_review %}
|
||||
security_concerns: str = Field(description="此 PR 代码是否引入了可能的漏洞, 例如敏感信息 (例如, API 密钥, 秘密, 密码) 的暴露, 或安全问题, 如 SQL 注入, XSS, CSRF 和其他 ? 如果没有可能的问题, 回答 'No' (不解释原因).如果存在安全隐患或问题, 请以简短的标题开头回答, 例如: '敏感信息泄露: ...', 'SQL 注入: ...' 等.解释你的答案.如果可能, 请具体说明并举例说明,内容尽量使用简体中文和英文标点符号.")
|
||||
{%- endif %}
|
||||
{%- if require_can_be_split_review %}
|
||||
can_be_split: List[SubPR] = Field(min_items=0, max_items=3, description="这个 PR 总共包含 {{ num_pr_files }} 个更改文件, 是否可以将其划分为更小的子 PR, 这些子 PR 具有可以独立审查和合并的不同任务, 而不考虑顺序 ? 确保子 PR 确实是独立的, 彼此之间没有代码依赖关系, 并且每个子 PR 都代表一个有意义的独立任务.如果 PR 代码不需要拆分, 则输出一个空列表.")
|
||||
{%- endif %}
|
||||
|
||||
class PRReview(BaseModel):
|
||||
review: Review
|
||||
=====
|
||||
|
||||
|
||||
示例输出:
|
||||
```yaml
|
||||
review:
|
||||
{%- if related_tickets %}
|
||||
ticket_compliance_check:
|
||||
- ticket_url: |
|
||||
...
|
||||
ticket_requirements: |
|
||||
...
|
||||
fully_compliant_requirements: |
|
||||
...
|
||||
not_compliant_requirements: |
|
||||
...
|
||||
overall_compliance_level: |
|
||||
...
|
||||
{%- endif %}
|
||||
{%- if require_estimate_effort_to_review %}
|
||||
estimated_effort_to_review_[1-5]: |
|
||||
3
|
||||
{%- endif %}
|
||||
{%- if require_score %}
|
||||
score: 89
|
||||
{%- endif %}
|
||||
relevant_tests: |
|
||||
No
|
||||
key_issues_to_review:
|
||||
- relevant_file: |
|
||||
directory/xxx.py
|
||||
issue_header: |
|
||||
可能存在的Bug
|
||||
issue_content: |
|
||||
...
|
||||
start_line: 12
|
||||
end_line: 14
|
||||
- ...
|
||||
security_concerns: |
|
||||
No
|
||||
{%- if require_can_be_split_review %}
|
||||
can_be_split:
|
||||
- relevant_files:
|
||||
- ...
|
||||
- ...
|
||||
title: ...
|
||||
- ...
|
||||
{%- endif %}
|
||||
```
|
||||
|
||||
答案应该是一个有效的 YAML, 并且不能包含其他内容. 每个 YAML 输出必须在新行之后, 具有适当的缩进和块标量指示符 ('|')
|
||||
"""
|
||||
|
||||
user="""
|
||||
{%- if related_tickets %}
|
||||
--PR Ticket Info--
|
||||
{%- for ticket in related_tickets %}
|
||||
=====
|
||||
Ticket URL: '{{ ticket.ticket_url }}'
|
||||
|
||||
Ticket Title: '{{ ticket.title }}'
|
||||
|
||||
{%- if ticket.labels %}
|
||||
|
||||
Ticket Labels: {{ ticket.labels }}
|
||||
|
||||
{%- endif %}
|
||||
{%- if ticket.body %}
|
||||
|
||||
Ticket Description:
|
||||
#####
|
||||
{{ ticket.body }}
|
||||
#####
|
||||
{%- endif %}
|
||||
=====
|
||||
{% endfor %}
|
||||
{%- endif %}
|
||||
|
||||
|
||||
--PR 信息--
|
||||
{%- if date %}
|
||||
|
||||
今天的日期: {{date}}
|
||||
{%- endif %}
|
||||
|
||||
标题: '{{title}}'
|
||||
|
||||
分支: '{{branch}}'
|
||||
|
||||
{%- if description %}
|
||||
|
||||
PR 描述:
|
||||
======
|
||||
{{ description|trim }}
|
||||
======
|
||||
{%- endif %}
|
||||
|
||||
{%- if question_str %}
|
||||
|
||||
=====
|
||||
以下是为更好地理解 PR 而提出的问题. 请根据这些回答提供更好的反馈.
|
||||
|
||||
{{ question_str|trim }}
|
||||
|
||||
用户回答:
|
||||
'
|
||||
{{ answer_str|trim }}
|
||||
'
|
||||
=====
|
||||
{%- endif %}
|
||||
|
||||
|
||||
PR 代码差异:
|
||||
======
|
||||
{{ diff|trim }}
|
||||
======
|
||||
|
||||
|
||||
{%- if duplicate_prompt_examples %}
|
||||
|
||||
|
||||
示例输出:
|
||||
```yaml
|
||||
review:
|
||||
{%- if related_tickets %}
|
||||
ticket_compliance_check:
|
||||
- ticket_url: |
|
||||
...
|
||||
ticket_requirements: |
|
||||
...
|
||||
fully_compliant_requirements: |
|
||||
...
|
||||
not_compliant_requirements: |
|
||||
...
|
||||
overall_compliance_level: |
|
||||
...
|
||||
{%- endif %}
|
||||
{%- if require_estimate_effort_to_review %}
|
||||
estimated_effort_to_review_[1-5]: |
|
||||
3
|
||||
{%- endif %}
|
||||
{%- if require_score %}
|
||||
score: 89
|
||||
{%- endif %}
|
||||
relevant_tests: |
|
||||
No
|
||||
key_issues_to_review:
|
||||
- relevant_file: |
|
||||
...
|
||||
issue_header: |
|
||||
...
|
||||
issue_content: |
|
||||
...
|
||||
start_line: ...
|
||||
end_line: ...
|
||||
- ...
|
||||
security_concerns: |
|
||||
No
|
||||
{%- if require_can_be_split_review %}
|
||||
can_be_split:
|
||||
- relevant_files:
|
||||
- ...
|
||||
- ...
|
||||
title: ...
|
||||
- ...
|
||||
{%- endif %}
|
||||
```
|
||||
(将 '...' 替换为实际值)
|
||||
{%- endif %}
|
||||
|
||||
|
||||
回复 (应该是一个有效的YAML, 没有其他内容):
|
||||
```yaml
|
||||
"""
|
||||
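Note (illustration only, not part of this commit): the reviewer prompt above asks the model for a fenced YAML reply matching the $PRReview Pydantic definition. A reply could be parsed and validated roughly as sketched below; the reduced class subset and the helper name `parse_review_reply` are assumptions, and PyYAML/pydantic are assumed to be available through the project's existing dependencies.

```python
import yaml
from typing import List
from pydantic import BaseModel, Field


class KeyIssuesComponentLink(BaseModel):
    relevant_file: str
    issue_header: str
    issue_content: str
    start_line: int
    end_line: int


class Review(BaseModel):
    # Only the unconditional fields of the prompt's Review model are mirrored here;
    # optional/conditional fields from the template are ignored as extras.
    relevant_tests: str = ""
    key_issues_to_review: List[KeyIssuesComponentLink] = Field(default_factory=list)
    security_concerns: str = ""


class PRReview(BaseModel):
    review: Review


def parse_review_reply(reply: str) -> PRReview:
    # Drop any surrounding code-fence lines the model may have added,
    # then load the remaining YAML and validate it.
    body = "\n".join(
        line for line in reply.strip().splitlines() if not line.strip().startswith("```")
    )
    return PRReview(**yaml.safe_load(body))
```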
@ -0,0 +1,46 @@
|
||||
[pr_sort_code_suggestions_prompt]
|
||||
system="""
|
||||
"""
|
||||
|
||||
user="""你将获得一个代码建议列表,用于改进 Git Pull Request (PR):
|
||||
======
|
||||
{{ suggestion_str|trim }}
|
||||
======
|
||||
|
||||
你的任务是根据代码建议的重要性顺序对其进行排序,并返回一个包含排序顺序的列表.
|
||||
排序顺序是一个由配对组成的列表,其中每对都包含原始列表中建议的索引.
|
||||
根据建议对改进 PR 的重要性对其进行排序,将关键问题放在首位,次要问题放在最后.
|
||||
|
||||
你必须使用以下 YAML 模式来格式化你的答案:
|
||||
```yaml
|
||||
Sort Order:
|
||||
type: array
|
||||
maxItems: {{ suggestion_list|length }}
|
||||
uniqueItems: true
|
||||
items:
|
||||
suggestion number:
|
||||
type: integer
|
||||
minimum: 1
|
||||
maximum: {{ suggestion_list|length }}
|
||||
importance order:
|
||||
type: integer
|
||||
minimum: 1
|
||||
maximum: {{ suggestion_list|length }}
|
||||
```
|
||||
|
||||
示例输出:
|
||||
```yaml
|
||||
Sort Order:
|
||||
- suggestion number: 1
|
||||
importance order: 2
|
||||
- suggestion number: 2
|
||||
importance order: 3
|
||||
- suggestion number: 3
|
||||
importance order: 1
|
||||
```
|
||||
|
||||
确保输出有效的 YAML.如果需要,请使用多行块标量 ('|').
|
||||
不要在答案中重复提示,并避免输出 'type' 和 'description' 字段.
|
||||
回应 (应该是一个有效的 YAML,没有其他内容):
|
||||
```yaml
|
||||
"""
|
||||
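Illustration only (not part of this commit): the 'Sort Order' pairs returned by the prompt above could be applied to reorder a suggestion list roughly as follows; `apply_sort_order` is a hypothetical helper name.

```python
import yaml


def apply_sort_order(suggestions: list, reply: str) -> list:
    # Drop the surrounding code-fence lines, if the model added them.
    body = "\n".join(
        line for line in reply.strip().splitlines() if not line.strip().startswith("```")
    )
    order = yaml.safe_load(body)["Sort Order"]
    # Each pair maps a 1-based "suggestion number" to its "importance order" rank;
    # sort by rank so the most important suggestion comes first.
    ranked = sorted(order, key=lambda pair: pair["importance order"])
    return [suggestions[pair["suggestion number"] - 1] for pair in ranked]
```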
@ -0,0 +1,70 @@
|
||||
[pr_update_changelog_prompt]
|
||||
system="""你是一个名为PR-Changelog-Updater的语言模型
|
||||
你的任务是向项目的CHANGELOG.md文件添加此PR更改的简短摘要:
|
||||
- 遵循文件现有的格式和风格约定,例如日期,章节标题等
|
||||
- 仅添加新更改 (不要重复现有条目)
|
||||
- 保持通用性, 避免涉及具体的细节或文件等. 输出应尽量简洁, 不超过 3-4 行短句
|
||||
- 仅编写要添加到 CHANGELOG.md 的新内容, 不要有任何介绍或总结. 内容应该看起来像是现有文件的自然组成部分
|
||||
{%- if pr_link %}
|
||||
- 如果相关, 使用 PR URL '{{ pr_link }}' 将更新日志主标题转换为可点击的链接. 格式: 标题 [*](pr_link)
|
||||
{%- endif %}
|
||||
|
||||
|
||||
{%- if extra_instructions %}
|
||||
|
||||
用户提供的额外指示:
|
||||
======
|
||||
{{ extra_instructions|trim }}
|
||||
======
|
||||
{%- endif %}
|
||||
"""
|
||||
|
||||
user="""PR 信息:
|
||||
|
||||
标题: '{{title}}'
|
||||
|
||||
分支: '{{branch}}'
|
||||
|
||||
{%- if description %}
|
||||
|
||||
描述:
|
||||
======
|
||||
{{ description|trim }}
|
||||
======
|
||||
{%- endif %}
|
||||
|
||||
{%- if language %}
|
||||
|
||||
主要PR语言: '{{ language }}'
|
||||
{%- endif %}
|
||||
{%- if commit_messages_str %}
|
||||
|
||||
|
||||
提交信息:
|
||||
======
|
||||
{{ commit_messages_str|trim }}
|
||||
======
|
||||
{%- endif %}
|
||||
|
||||
|
||||
PR Git 差异:
|
||||
======
|
||||
{{ diff|trim }}
|
||||
======
|
||||
|
||||
|
||||
当前日期:
|
||||
```
|
||||
{{today}}
|
||||
```
|
||||
|
||||
|
||||
当前的 'CHANGELOG.md' 文件
|
||||
======
|
||||
{{ changelog_file_str }}
|
||||
======
|
||||
|
||||
|
||||
回复:
|
||||
```markdown
|
||||
"""
|
||||
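Illustration only (not part of this commit): since the prompt asks the model to emit just the new markdown lines, they could be prepended to the existing changelog roughly like this; `prepend_changelog_entry` is a hypothetical helper.

```python
from pathlib import Path


def prepend_changelog_entry(new_entry: str, changelog_path: str = "CHANGELOG.md") -> None:
    # The prompt instructs the model to return only the new lines, so they can be
    # placed directly above the existing content.
    path = Path(changelog_path)
    existing = path.read_text(encoding="utf-8") if path.exists() else ""
    path.write_text(new_entry.rstrip() + "\n\n" + existing, encoding="utf-8")
```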
0
apps/utils/pr_agent/tools/__init__.py
Normal file
180
apps/utils/pr_agent/tools/pr_add_docs.py
Normal file
@ -0,0 +1,180 @@
|
||||
import copy
|
||||
import textwrap
|
||||
from functools import partial
|
||||
from typing import Dict
|
||||
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
|
||||
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
|
||||
from utils.pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||
from utils.pr_agent.algo.token_handler import TokenHandler
|
||||
from utils.pr_agent.algo.utils import load_yaml
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
from utils.pr_agent.git_providers import get_git_provider
|
||||
from utils.pr_agent.git_providers.git_provider import get_main_pr_language
|
||||
from utils.pr_agent.log import get_logger
|
||||
|
||||
|
||||
class PRAddDocs:
|
||||
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
|
||||
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
self.main_language = get_main_pr_language(
|
||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||
)
|
||||
|
||||
self.ai_handler = ai_handler()
|
||||
self.ai_handler.main_pr_language = self.main_language
|
||||
|
||||
self.patches_diff = None
|
||||
self.prediction = None
|
||||
self.cli_mode = cli_mode
|
||||
self.vars = {
|
||||
"title": self.git_provider.pr.title,
|
||||
"branch": self.git_provider.get_pr_branch(),
|
||||
"description": self.git_provider.get_pr_description(),
|
||||
"language": self.main_language,
|
||||
"diff": "", # empty diff for initial calculation
|
||||
"extra_instructions": get_settings().pr_add_docs.extra_instructions,
|
||||
"commit_messages_str": self.git_provider.get_commit_messages(),
|
||||
'docs_for_language': get_docs_for_language(self.main_language,
|
||||
get_settings().pr_add_docs.docs_style),
|
||||
}
|
||||
self.token_handler = TokenHandler(self.git_provider.pr,
|
||||
self.vars,
|
||||
get_settings().pr_add_docs_prompt.system,
|
||||
get_settings().pr_add_docs_prompt.user)
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
get_logger().info('Generating code Docs for PR...')
|
||||
if get_settings().config.publish_output:
|
||||
self.git_provider.publish_comment("生成文档中...", is_temporary=True)
|
||||
|
||||
get_logger().info('Preparing PR documentation...')
|
||||
await retry_with_fallback_models(self._prepare_prediction)
|
||||
data = self._prepare_pr_code_docs()
|
||||
if (not data) or ('Code Documentation' not in data):
|
||||
get_logger().info('No code documentation found for PR.')
|
||||
return
|
||||
|
||||
if get_settings().config.publish_output:
|
||||
get_logger().info('Pushing PR documentation...')
|
||||
self.git_provider.remove_initial_comment()
|
||||
get_logger().info('Pushing inline code documentation...')
|
||||
self.push_inline_docs(data)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to generate code documentation for PR, error: {e}")
|
||||
|
||||
async def _prepare_prediction(self, model: str):
|
||||
get_logger().info('Getting PR diff...')
|
||||
|
||||
self.patches_diff = get_pr_diff(self.git_provider,
|
||||
self.token_handler,
|
||||
model,
|
||||
add_line_numbers_to_hunks=True,
|
||||
disable_extra_lines=False)
|
||||
|
||||
get_logger().info('Getting AI prediction...')
|
||||
self.prediction = await self._get_prediction(model)
|
||||
|
||||
async def _get_prediction(self, model: str):
|
||||
variables = copy.deepcopy(self.vars)
|
||||
variables["diff"] = self.patches_diff # update diff
|
||||
environment = Environment(undefined=StrictUndefined)
|
||||
system_prompt = environment.from_string(get_settings().pr_add_docs_prompt.system).render(variables)
|
||||
user_prompt = environment.from_string(get_settings().pr_add_docs_prompt.user).render(variables)
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"\nSystem prompt:\n{system_prompt}")
|
||||
get_logger().info(f"\nUser prompt:\n{user_prompt}")
|
||||
response, finish_reason = await self.ai_handler.chat_completion(
|
||||
model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
|
||||
|
||||
return response
|
||||
|
||||
def _prepare_pr_code_docs(self) -> Dict:
|
||||
docs = self.prediction.strip()
|
||||
data = load_yaml(docs)
|
||||
if isinstance(data, list):
|
||||
data = {'Code Documentation': data}
|
||||
return data
|
||||
|
||||
def push_inline_docs(self, data):
|
||||
docs = []
|
||||
|
||||
if not data['Code Documentation']:
|
||||
return self.git_provider.publish_comment('No code documentation found to improve this PR.')
|
||||
|
||||
for d in data['Code Documentation']:
|
||||
try:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"add_docs: {d}")
|
||||
relevant_file = d['relevant file'].strip()
|
||||
relevant_line = int(d['relevant line']) # absolute position
|
||||
documentation = d['documentation']
|
||||
doc_placement = d['doc placement'].strip()
|
||||
if documentation:
|
||||
new_code_snippet = self.dedent_code(relevant_file, relevant_line, documentation, doc_placement,
|
||||
add_original_line=True)
|
||||
|
||||
body = f"**Suggestion:** Proposed documentation\n```suggestion\n" + new_code_snippet + "\n```"
|
||||
docs.append({'body': body, 'relevant_file': relevant_file,
|
||||
'relevant_lines_start': relevant_line,
|
||||
'relevant_lines_end': relevant_line})
|
||||
except Exception:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Could not parse code docs: {d}")
|
||||
|
||||
is_successful = self.git_provider.publish_code_suggestions(docs)
|
||||
if not is_successful:
|
||||
get_logger().info("Failed to publish code docs, trying to publish each docs separately")
|
||||
for doc_suggestion in docs:
|
||||
self.git_provider.publish_code_suggestions([doc_suggestion])
|
||||
|
||||
def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet, doc_placement='after',
|
||||
add_original_line=False):
|
||||
try: # dedent code snippet
|
||||
self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files \
|
||||
else self.git_provider.get_diff_files()
|
||||
original_initial_line = None
|
||||
for file in self.diff_files:
|
||||
if file.filename.strip() == relevant_file:
|
||||
original_initial_line = file.head_file.splitlines()[relevant_lines_start - 1]
|
||||
break
|
||||
if original_initial_line:
|
||||
if doc_placement == 'after':
|
||||
line = file.head_file.splitlines()[relevant_lines_start]
|
||||
else:
|
||||
line = original_initial_line
|
||||
suggested_initial_line = new_code_snippet.splitlines()[0]
|
||||
original_initial_spaces = len(line) - len(line.lstrip())
|
||||
suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip())
|
||||
delta_spaces = original_initial_spaces - suggested_initial_spaces
|
||||
if delta_spaces > 0:
|
||||
new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
|
||||
if add_original_line:
|
||||
if doc_placement == 'after':
|
||||
new_code_snippet = original_initial_line + "\n" + new_code_snippet
|
||||
else:
|
||||
new_code_snippet = new_code_snippet.rstrip() + "\n" + original_initial_line
|
||||
except Exception as e:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")
|
||||
|
||||
return new_code_snippet
|
||||
|
||||
|
||||
def get_docs_for_language(language, style):
|
||||
language = language.lower()
|
||||
if language == 'java':
|
||||
return "Javadocs"
|
||||
elif language in ['python', 'lisp', 'clojure']:
|
||||
return f"Docstring ({style})"
|
||||
elif language in ['javascript', 'typescript']:
|
||||
return "JSdocs"
|
||||
elif language == 'c++':
|
||||
return "Doxygen"
|
||||
else:
|
||||
return "Docs"
|
||||
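Illustration only (not part of this commit): PRAddDocs could be driven from a small script roughly as follows, assuming the pr_agent settings and git-provider credentials are already configured; the PR URL is a placeholder.

```python
import asyncio

from utils.pr_agent.tools.pr_add_docs import PRAddDocs


async def main() -> None:
    # Placeholder PR URL; the matching git provider and its access token must
    # already be configured in the pr_agent settings.
    tool = PRAddDocs(pr_url="https://example.com/org/repo/pull/1")
    await tool.run()


if __name__ == "__main__":
    asyncio.run(main())
```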
872
apps/utils/pr_agent/tools/pr_code_suggestions.py
Normal file
@ -0,0 +1,872 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import difflib
|
||||
import re
|
||||
import textwrap
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Dict, List
|
||||
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
|
||||
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
|
||||
from utils.pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files,
|
||||
get_pr_diff, get_pr_multi_diffs,
|
||||
retry_with_fallback_models)
|
||||
from utils.pr_agent.algo.token_handler import TokenHandler
|
||||
from utils.pr_agent.algo.utils import (ModelType, load_yaml, replace_code_tags,
|
||||
show_relevant_configurations)
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
from utils.pr_agent.git_providers import (AzureDevopsProvider, GithubProvider,
|
||||
get_git_provider_with_context)
|
||||
from utils.pr_agent.git_providers.git_provider import get_main_pr_language, GitProvider
|
||||
from utils.pr_agent.log import get_logger
|
||||
from utils.pr_agent.servers.help import HelpMessage
|
||||
from utils.pr_agent.tools.pr_description import insert_br_after_x_chars
|
||||
|
||||
|
||||
class PRCodeSuggestions:
|
||||
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
|
||||
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||
|
||||
self.git_provider = get_git_provider_with_context(pr_url)
|
||||
self.main_language = get_main_pr_language(
|
||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||
)
|
||||
|
||||
# limit context specifically for the improve command, which has hard input to parse:
|
||||
if get_settings().pr_code_suggestions.max_context_tokens:
|
||||
MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
|
||||
if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
|
||||
get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
|
||||
get_settings().config.max_model_tokens_original = get_settings().config.max_model_tokens
|
||||
get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
|
||||
|
||||
# extended mode
|
||||
try:
|
||||
self.is_extended = self._get_is_extended(args or [])
|
||||
except:
|
||||
self.is_extended = False
|
||||
num_code_suggestions = int(get_settings().pr_code_suggestions.num_code_suggestions_per_chunk)
|
||||
|
||||
|
||||
self.ai_handler = ai_handler()
|
||||
self.ai_handler.main_pr_language = self.main_language
|
||||
self.patches_diff = None
|
||||
self.prediction = None
|
||||
self.pr_url = pr_url
|
||||
self.cli_mode = cli_mode
|
||||
self.pr_description, self.pr_description_files = (
|
||||
self.git_provider.get_pr_description(split_changes_walkthrough=True))
|
||||
if (self.pr_description_files and get_settings().get("config.is_auto_command", False) and
|
||||
get_settings().get("config.enable_ai_metadata", False)):
|
||||
add_ai_metadata_to_diff_files(self.git_provider, self.pr_description_files)
|
||||
get_logger().debug(f"AI metadata added to the this command")
|
||||
else:
|
||||
get_settings().set("config.enable_ai_metadata", False)
|
||||
get_logger().debug(f"AI metadata is disabled for this command")
|
||||
|
||||
self.vars = {
|
||||
"title": self.git_provider.pr.title,
|
||||
"branch": self.git_provider.get_pr_branch(),
|
||||
"description": self.pr_description,
|
||||
"language": self.main_language,
|
||||
"diff": "", # empty diff for initial calculation
|
||||
"diff_no_line_numbers": "", # empty diff for initial calculation
|
||||
"num_code_suggestions": num_code_suggestions,
|
||||
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
|
||||
"commit_messages_str": self.git_provider.get_commit_messages(),
|
||||
"relevant_best_practices": "",
|
||||
"is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
|
||||
"focus_only_on_problems": get_settings().get("pr_code_suggestions.focus_only_on_problems", False),
|
||||
"date": datetime.now().strftime('%Y-%m-%d'),
|
||||
'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False),
|
||||
}
|
||||
self.pr_code_suggestions_prompt_system = get_settings().pr_code_suggestions_prompt.system
|
||||
|
||||
self.token_handler = TokenHandler(self.git_provider.pr,
|
||||
self.vars,
|
||||
self.pr_code_suggestions_prompt_system,
|
||||
get_settings().pr_code_suggestions_prompt.user)
|
||||
|
||||
self.progress = f"## 生成 PR 代码建议\n\n"
|
||||
self.progress += f"""\n思考中 ...<br>\n<img src="https://codium.ai/images/pr_agent/dual_ball_loading-crop.gif" width=48>"""
|
||||
self.progress_response = None
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
if not self.git_provider.get_files():
|
||||
get_logger().info(f"PR has no files: {self.pr_url}, skipping code suggestions")
|
||||
return None
|
||||
|
||||
get_logger().info('Generating code suggestions for PR...')
|
||||
relevant_configs = {'pr_code_suggestions': dict(get_settings().pr_code_suggestions),
|
||||
'config': dict(get_settings().config)}
|
||||
get_logger().debug("Relevant configs", artifacts=relevant_configs)
|
||||
|
||||
# publish "Preparing suggestions..." comments
|
||||
if (get_settings().config.publish_output and get_settings().config.publish_output_progress and
|
||||
not get_settings().config.get('is_auto_command', False)):
|
||||
if self.git_provider.is_supported("gfm_markdown"):
|
||||
self.progress_response = self.git_provider.publish_comment(self.progress)
|
||||
else:
|
||||
self.git_provider.publish_comment("准备建议中...", is_temporary=True)
|
||||
|
||||
# call the model to get the suggestions, and self-reflect on them
|
||||
if not self.is_extended:
|
||||
data = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
|
||||
else:
|
||||
data = await retry_with_fallback_models(self._prepare_prediction_extended, model_type=ModelType.REGULAR)
|
||||
if not data:
|
||||
data = {"code_suggestions": []}
|
||||
self.data = data
|
||||
|
||||
# Handle the case where the PR has no suggestions
|
||||
if (data is None or 'code_suggestions' not in data or not data['code_suggestions']):
|
||||
await self.publish_no_suggestions()
|
||||
return
|
||||
|
||||
# publish the suggestions
|
||||
if get_settings().config.publish_output:
|
||||
# If a temporary comment was published, remove it
|
||||
self.git_provider.remove_initial_comment()
|
||||
|
||||
# Publish table summarized suggestions
|
||||
if ((not get_settings().pr_code_suggestions.commitable_code_suggestions) and
|
||||
self.git_provider.is_supported("gfm_markdown")):
|
||||
|
||||
# generate summarized suggestions
|
||||
pr_body = self.generate_summarized_suggestions(data)
|
||||
get_logger().debug(f"PR output", artifact=pr_body)
|
||||
|
||||
# require self-review
|
||||
if get_settings().pr_code_suggestions.demand_code_suggestions_self_review:
|
||||
pr_body = await self.add_self_review_text(pr_body)
|
||||
|
||||
# add usage guide
|
||||
if (get_settings().pr_code_suggestions.enable_chat_text and get_settings().config.is_auto_command
|
||||
and isinstance(self.git_provider, GithubProvider)):
|
||||
pr_body += "\n\n>💡 Need additional feedback ? start a [PR chat](https://chromewebstore.google.com/detail/ephlnjeghhogofkifjloamocljapahnl) \n\n"
|
||||
if get_settings().pr_code_suggestions.enable_help_text:
|
||||
pr_body += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
|
||||
pr_body += HelpMessage.get_improve_usage_guide()
|
||||
pr_body += "\n</details>\n"
|
||||
|
||||
# Output the relevant configurations if enabled
|
||||
if get_settings().get('config', {}).get('output_relevant_configurations', False):
|
||||
pr_body += show_relevant_configurations(relevant_section='pr_code_suggestions')
|
||||
|
||||
# publish the PR comment
|
||||
if get_settings().pr_code_suggestions.persistent_comment: # true by default
|
||||
self.publish_persistent_comment_with_history(self.git_provider,
|
||||
pr_body,
|
||||
initial_header="## PR 代码建议 ✨",
|
||||
update_header=True,
|
||||
name="suggestions",
|
||||
final_update_message=False,
|
||||
max_previous_comments=get_settings().pr_code_suggestions.max_history_len,
|
||||
progress_response=self.progress_response)
|
||||
else:
|
||||
if self.progress_response:
|
||||
self.git_provider.edit_comment(self.progress_response, body=pr_body)
|
||||
else:
|
||||
self.git_provider.publish_comment(pr_body)
|
||||
|
||||
# dual publishing mode
|
||||
if int(get_settings().pr_code_suggestions.dual_publishing_score_threshold) > 0:
|
||||
await self.dual_publishing(data)
|
||||
else:
|
||||
await self.push_inline_code_suggestions(data)
|
||||
if self.progress_response:
|
||||
self.git_provider.remove_comment(self.progress_response)
|
||||
else:
|
||||
get_logger().info('Code suggestions generated for PR, but not published since publish_output is False.')
|
||||
pr_body = self.generate_summarized_suggestions(data)
|
||||
get_settings().data = {"artifact": pr_body}
|
||||
return
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to generate code suggestions for PR, error: {e}",
|
||||
artifact={"traceback": traceback.format_exc()})
|
||||
if get_settings().config.publish_output:
|
||||
if self.progress_response:
|
||||
self.progress_response.delete()
|
||||
else:
|
||||
try:
|
||||
self.git_provider.remove_initial_comment()
|
||||
self.git_provider.publish_comment(f"Failed to generate code suggestions for PR")
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to update persistent review, error: {e}")
|
||||
|
||||
async def add_self_review_text(self, pr_body):
|
||||
text = get_settings().pr_code_suggestions.code_suggestions_self_review_text
|
||||
pr_body += f"\n\n- [ ] {text}"
|
||||
approve_pr_on_self_review = get_settings().pr_code_suggestions.approve_pr_on_self_review
|
||||
fold_suggestions_on_self_review = get_settings().pr_code_suggestions.fold_suggestions_on_self_review
|
||||
if approve_pr_on_self_review and not fold_suggestions_on_self_review:
|
||||
pr_body += ' <!-- approve pr self-review -->'
|
||||
elif fold_suggestions_on_self_review and not approve_pr_on_self_review:
|
||||
pr_body += ' <!-- fold suggestions self-review -->'
|
||||
else:
|
||||
pr_body += ' <!-- approve and fold suggestions self-review -->'
|
||||
return pr_body
|
||||
|
||||
async def publish_no_suggestions(self):
|
||||
pr_body = "## PR 代码建议 ✨\n\n未找到该PR的代码建议."
|
||||
if get_settings().config.publish_output and get_settings().config.publish_output_no_suggestions:
|
||||
get_logger().warning('No code suggestions found for the PR.')
|
||||
get_logger().debug(f"PR output", artifact=pr_body)
|
||||
if self.progress_response:
|
||||
self.git_provider.edit_comment(self.progress_response, body=pr_body)
|
||||
else:
|
||||
self.git_provider.publish_comment(pr_body)
|
||||
else:
|
||||
get_settings().data = {"artifact": ""}
|
||||
|
||||
async def dual_publishing(self, data):
|
||||
data_above_threshold = {'code_suggestions': []}
|
||||
try:
|
||||
for suggestion in data['code_suggestions']:
|
||||
if int(suggestion.get('score', 0)) >= int(
|
||||
get_settings().pr_code_suggestions.dual_publishing_score_threshold) \
|
||||
and suggestion.get('improved_code'):
|
||||
data_above_threshold['code_suggestions'].append(suggestion)
|
||||
if not data_above_threshold['code_suggestions'][-1]['existing_code']:
|
||||
get_logger().info(f'Identical existing and improved code for dual publishing found')
|
||||
data_above_threshold['code_suggestions'][-1]['existing_code'] = suggestion[
|
||||
'improved_code']
|
||||
if data_above_threshold['code_suggestions']:
|
||||
get_logger().info(
|
||||
f"Publishing {len(data_above_threshold['code_suggestions'])} suggestions in dual publishing mode")
|
||||
await self.push_inline_code_suggestions(data_above_threshold)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to publish dual publishing suggestions, error: {e}")
|
||||
|
||||
@staticmethod
|
||||
def publish_persistent_comment_with_history(git_provider: GitProvider,
|
||||
pr_comment: str,
|
||||
initial_header: str,
|
||||
update_header: bool = True,
|
||||
name='review',
|
||||
final_update_message=True,
|
||||
max_previous_comments=4,
|
||||
progress_response=None,
|
||||
only_fold=False):
|
||||
|
||||
def _extract_link(comment_text: str):
|
||||
r = re.compile(r"<!--.*?-->")
|
||||
match = r.search(comment_text)
|
||||
|
||||
up_to_commit_txt = ""
|
||||
if match:
|
||||
up_to_commit_txt = f" up to commit {match.group(0)[4:-3].strip()}"
|
||||
return up_to_commit_txt
|
||||
|
||||
if isinstance(git_provider, AzureDevopsProvider): # get_latest_commit_url is not supported yet
|
||||
if progress_response:
|
||||
git_provider.edit_comment(progress_response, pr_comment)
|
||||
new_comment = progress_response
|
||||
else:
|
||||
new_comment = git_provider.publish_comment(pr_comment)
|
||||
return new_comment
|
||||
|
||||
history_header = f"#### Previous suggestions\n"
|
||||
last_commit_num = git_provider.get_latest_commit_url().split('/')[-1][:7]
|
||||
if only_fold: # A user clicked on the 'self-review' checkbox
|
||||
text = get_settings().pr_code_suggestions.code_suggestions_self_review_text
|
||||
latest_suggestion_header = f"\n\n- [x] {text}"
|
||||
else:
|
||||
latest_suggestion_header = f"Latest suggestions up to {last_commit_num}"
|
||||
latest_commit_html_comment = f"<!-- {last_commit_num} -->"
|
||||
found_comment = None
|
||||
|
||||
if max_previous_comments > 0:
|
||||
try:
|
||||
prev_comments = list(git_provider.get_issue_comments())
|
||||
for comment in prev_comments:
|
||||
if comment.body.startswith(initial_header):
|
||||
prev_suggestions = comment.body
|
||||
found_comment = comment
|
||||
comment_url = git_provider.get_comment_url(comment)
|
||||
|
||||
if history_header.strip() not in comment.body:
|
||||
# no history section
|
||||
# extract everything between <table> and </table> in comment.body including <table> and </table>
|
||||
table_index = comment.body.find("<table>")
|
||||
if table_index == -1:
|
||||
git_provider.edit_comment(comment, pr_comment)
|
||||
continue
|
||||
# find http link from comment.body[:table_index]
|
||||
up_to_commit_txt = _extract_link(comment.body[:table_index])
|
||||
prev_suggestion_table = comment.body[
|
||||
table_index:comment.body.rfind("</table>") + len("</table>")]
|
||||
|
||||
tick = "✅ " if "✅" in prev_suggestion_table else ""
|
||||
# surround with details tag
|
||||
prev_suggestion_table = f"<details><summary>{tick}{name.capitalize()}{up_to_commit_txt}</summary>\n<br>{prev_suggestion_table}\n\n</details>"
|
||||
|
||||
new_suggestion_table = pr_comment.replace(initial_header, "").strip()
|
||||
|
||||
pr_comment_updated = f"{initial_header}\n{latest_commit_html_comment}\n\n"
|
||||
pr_comment_updated += f"{latest_suggestion_header}\n{new_suggestion_table}\n\n___\n\n"
|
||||
pr_comment_updated += f"{history_header}{prev_suggestion_table}\n"
|
||||
else:
|
||||
# get the text of the previous suggestions until the latest commit
|
||||
sections = prev_suggestions.split(history_header.strip())
|
||||
latest_table = sections[0].strip()
|
||||
prev_suggestion_table = sections[1].replace(history_header, "").strip()
|
||||
|
||||
# get text after the latest_suggestion_header in comment.body
|
||||
table_ind = latest_table.find("<table>")
|
||||
up_to_commit_txt = _extract_link(latest_table[:table_ind])
|
||||
|
||||
latest_table = latest_table[table_ind:latest_table.rfind("</table>") + len("</table>")]
|
||||
# enforce max_previous_comments
|
||||
count = prev_suggestions.count(f"\n<details><summary>{name.capitalize()}")
|
||||
count += prev_suggestions.count(f"\n<details><summary>✅ {name.capitalize()}")
|
||||
if count >= max_previous_comments:
|
||||
# remove the oldest suggestion
|
||||
prev_suggestion_table = prev_suggestion_table[:prev_suggestion_table.rfind(
|
||||
f"<details><summary>{name.capitalize()} up to commit")]
|
||||
|
||||
tick = "✅ " if "✅" in latest_table else ""
|
||||
# Add to the prev_suggestions section
|
||||
last_prev_table = f"\n<details><summary>{tick}{name.capitalize()}{up_to_commit_txt}</summary>\n<br>{latest_table}\n\n</details>"
|
||||
prev_suggestion_table = last_prev_table + "\n" + prev_suggestion_table
|
||||
|
||||
new_suggestion_table = pr_comment.replace(initial_header, "").strip()
|
||||
|
||||
pr_comment_updated = f"{initial_header}\n"
|
||||
pr_comment_updated += f"{latest_commit_html_comment}\n\n"
|
||||
pr_comment_updated += f"{latest_suggestion_header}\n\n{new_suggestion_table}\n\n"
|
||||
pr_comment_updated += "___\n\n"
|
||||
pr_comment_updated += f"{history_header}\n"
|
||||
pr_comment_updated += f"{prev_suggestion_table}\n"
|
||||
|
||||
get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message")
|
||||
if progress_response: # publish to 'progress_response' comment, because it refreshes immediately
|
||||
git_provider.edit_comment(progress_response, pr_comment_updated)
|
||||
git_provider.remove_comment(comment)
|
||||
comment = progress_response
|
||||
else:
|
||||
git_provider.edit_comment(comment, pr_comment_updated)
|
||||
return comment
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to update persistent review, error: {e}")
|
||||
pass
|
||||
|
||||
# if we are here, we did not find a previous comment to update
|
||||
body = pr_comment.replace(initial_header, "").strip()
|
||||
pr_comment = f"{initial_header}\n\n{latest_commit_html_comment}\n\n{body}\n\n"
|
||||
if progress_response:
|
||||
git_provider.edit_comment(progress_response, pr_comment)
|
||||
new_comment = progress_response
|
||||
else:
|
||||
new_comment = git_provider.publish_comment(pr_comment)
|
||||
return new_comment
|
||||
|
||||
|
||||
def extract_link(self, s):
|
||||
r = re.compile(r"<!--.*?-->")
|
||||
match = r.search(s)
|
||||
|
||||
up_to_commit_txt = ""
|
||||
if match:
|
||||
up_to_commit_txt = f" up to commit {match.group(0)[4:-3].strip()}"
|
||||
return up_to_commit_txt
|
||||
|
||||
async def _prepare_prediction(self, model: str) -> dict:
|
||||
self.patches_diff = get_pr_diff(self.git_provider,
|
||||
self.token_handler,
|
||||
model,
|
||||
add_line_numbers_to_hunks=True,
|
||||
disable_extra_lines=False)
|
||||
self.patches_diff_list = [self.patches_diff]
|
||||
self.patches_diff_no_line_number = self.remove_line_numbers([self.patches_diff])[0]
|
||||
|
||||
if self.patches_diff:
|
||||
get_logger().debug(f"PR diff", artifact=self.patches_diff)
|
||||
self.prediction = await self._get_prediction(model, self.patches_diff, self.patches_diff_no_line_number)
|
||||
else:
|
||||
get_logger().warning(f"Empty PR diff")
|
||||
self.prediction = None
|
||||
|
||||
data = self.prediction
|
||||
return data
|
||||
|
||||
async def _get_prediction(self, model: str, patches_diff: str, patches_diff_no_line_number: str) -> dict:
|
||||
variables = copy.deepcopy(self.vars)
|
||||
variables["diff"] = patches_diff # update diff
|
||||
variables["diff_no_line_numbers"] = patches_diff_no_line_number # update diff
|
||||
environment = Environment(undefined=StrictUndefined)
|
||||
system_prompt = environment.from_string(self.pr_code_suggestions_prompt_system).render(variables)
|
||||
user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
|
||||
response, finish_reason = await self.ai_handler.chat_completion(
|
||||
model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
|
||||
if not get_settings().config.publish_output:
|
||||
get_settings().system_prompt = system_prompt
|
||||
get_settings().user_prompt = user_prompt
|
||||
|
||||
# load suggestions from the AI response
|
||||
data = self._prepare_pr_code_suggestions(response)
|
||||
|
||||
# self-reflect on suggestions (mandatory, since line numbers are generated now here)
|
||||
model_reflection = get_settings().config.model
|
||||
response_reflect = await self.self_reflect_on_suggestions(data["code_suggestions"],
|
||||
patches_diff, model=model_reflection)
|
||||
if response_reflect:
|
||||
await self.analyze_self_reflection_response(data, response_reflect)
|
||||
else:
|
||||
# get_logger().error(f"Could not self-reflect on suggestions. using default score 7")
|
||||
for i, suggestion in enumerate(data["code_suggestions"]):
|
||||
suggestion["score"] = 7
|
||||
suggestion["score_why"] = ""
|
||||
|
||||
return data
|
||||
|
||||
async def analyze_self_reflection_response(self, data, response_reflect):
|
||||
response_reflect_yaml = load_yaml(response_reflect)
|
||||
code_suggestions_feedback = response_reflect_yaml.get("code_suggestions", [])
|
||||
if code_suggestions_feedback and len(code_suggestions_feedback) == len(data["code_suggestions"]):
|
||||
for i, suggestion in enumerate(data["code_suggestions"]):
|
||||
try:
|
||||
suggestion["score"] = code_suggestions_feedback[i]["suggestion_score"]
|
||||
suggestion["score_why"] = code_suggestions_feedback[i]["why"]
|
||||
|
||||
if 'relevant_lines_start' not in suggestion:
|
||||
relevant_lines_start = code_suggestions_feedback[i].get('relevant_lines_start', -1)
|
||||
relevant_lines_end = code_suggestions_feedback[i].get('relevant_lines_end', -1)
|
||||
suggestion['relevant_lines_start'] = relevant_lines_start
|
||||
suggestion['relevant_lines_end'] = relevant_lines_end
|
||||
if relevant_lines_start < 0 or relevant_lines_end < 0:
|
||||
suggestion["score"] = 0
|
||||
|
||||
try:
|
||||
if get_settings().config.publish_output:
|
||||
if not suggestion["score"]:
|
||||
score = -1
|
||||
else:
|
||||
score = int(suggestion["score"])
|
||||
label = suggestion["label"].lower().strip()
|
||||
label = label.replace('<br>', ' ')
|
||||
suggestion_statistics_dict = {'score': score,
|
||||
'label': label}
|
||||
get_logger().info(f"PR-Agent suggestions statistics",
|
||||
statistics=suggestion_statistics_dict, analytics=True)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to log suggestion statistics, error: {e}")
|
||||
pass
|
||||
|
||||
except Exception as e: #
|
||||
get_logger().error(f"Error processing suggestion score {i}",
|
||||
artifact={"suggestion": suggestion,
|
||||
"code_suggestions_feedback": code_suggestions_feedback[i]})
|
||||
suggestion["score"] = 7
|
||||
suggestion["score_why"] = ""
|
||||
|
||||
# if the before and after code is the same, clear one of them
|
||||
try:
|
||||
if suggestion['existing_code'] == suggestion['improved_code']:
|
||||
get_logger().debug(
|
||||
f"edited improved suggestion {i + 1}, because equal to existing code: {suggestion['existing_code']}")
|
||||
if get_settings().pr_code_suggestions.commitable_code_suggestions:
|
||||
suggestion['improved_code'] = "" # we need 'existing_code' to locate the code in the PR
|
||||
else:
|
||||
suggestion['existing_code'] = ""
|
||||
except Exception as e:
|
||||
get_logger().error(f"Error processing suggestion {i + 1}, error: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _truncate_if_needed(suggestion):
|
||||
max_code_suggestion_length = get_settings().get("PR_CODE_SUGGESTIONS.MAX_CODE_SUGGESTION_LENGTH", 0)
|
||||
suggestion_truncation_message = get_settings().get("PR_CODE_SUGGESTIONS.SUGGESTION_TRUNCATION_MESSAGE", "")
|
||||
if max_code_suggestion_length > 0:
|
||||
if len(suggestion['improved_code']) > max_code_suggestion_length:
|
||||
get_logger().info(f"Truncated suggestion from {len(suggestion['improved_code'])} "
|
||||
f"characters to {max_code_suggestion_length} characters")
|
||||
suggestion['improved_code'] = suggestion['improved_code'][:max_code_suggestion_length]
|
||||
suggestion['improved_code'] += f"\n{suggestion_truncation_message}"
|
||||
return suggestion
|
||||
|
||||
def _prepare_pr_code_suggestions(self, predictions: str) -> Dict:
|
||||
data = load_yaml(predictions.strip(),
|
||||
keys_fix_yaml=["relevant_file", "suggestion_content", "existing_code", "improved_code"],
|
||||
first_key="code_suggestions", last_key="label")
|
||||
if isinstance(data, list):
|
||||
data = {'code_suggestions': data}
|
||||
|
||||
# remove or edit invalid suggestions
|
||||
suggestion_list = []
|
||||
one_sentence_summary_list = []
|
||||
for i, suggestion in enumerate(data['code_suggestions']):
|
||||
try:
|
||||
needed_keys = ['one_sentence_summary', 'label', 'relevant_file']
|
||||
is_valid_keys = True
|
||||
for key in needed_keys:
|
||||
if key not in suggestion:
|
||||
is_valid_keys = False
|
||||
get_logger().debug(
|
||||
f"Skipping suggestion {i + 1}, because it does not contain '{key}':\n'{suggestion}")
|
||||
break
|
||||
if not is_valid_keys:
|
||||
continue
|
||||
|
||||
if get_settings().get("pr_code_suggestions.focus_only_on_problems", False):
|
||||
CRITICAL_LABEL = 'critical'
|
||||
if CRITICAL_LABEL in suggestion['label'].lower(): # we want the published labels to be less declarative
|
||||
suggestion['label'] = 'possible issue'
|
||||
|
||||
if suggestion['one_sentence_summary'] in one_sentence_summary_list:
|
||||
get_logger().debug(f"Skipping suggestion {i + 1}, because it is a duplicate: {suggestion}")
|
||||
continue
|
||||
|
||||
if 'const' in suggestion['suggestion_content'] and 'instead' in suggestion[
|
||||
'suggestion_content'] and 'let' in suggestion['suggestion_content']:
|
||||
get_logger().debug(
|
||||
f"Skipping suggestion {i + 1}, because it uses 'const instead let': {suggestion}")
|
||||
continue
|
||||
|
||||
if ('existing_code' in suggestion) and ('improved_code' in suggestion):
|
||||
suggestion = self._truncate_if_needed(suggestion)
|
||||
one_sentence_summary_list.append(suggestion['one_sentence_summary'])
|
||||
suggestion_list.append(suggestion)
|
||||
else:
|
||||
get_logger().info(
|
||||
f"Skipping suggestion {i + 1}, because it does not contain 'existing_code' or 'improved_code': {suggestion}")
|
||||
except Exception as e:
|
||||
get_logger().error(f"Error processing suggestion {i + 1}: {suggestion}, error: {e}")
|
||||
data['code_suggestions'] = suggestion_list
|
||||
|
||||
return data
|
||||
|
||||
async def push_inline_code_suggestions(self, data):
|
||||
code_suggestions = []
|
||||
|
||||
if not data['code_suggestions']:
|
||||
get_logger().info('No suggestions found to improve this PR.')
|
||||
if self.progress_response:
|
||||
return self.git_provider.edit_comment(self.progress_response,
|
||||
body='No suggestions found to improve this PR.')
|
||||
else:
|
||||
return self.git_provider.publish_comment('No suggestions found to improve this PR.')
|
||||
|
||||
for d in data['code_suggestions']:
|
||||
try:
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"suggestion: {d}")
|
||||
relevant_file = d['relevant_file'].strip()
|
||||
relevant_lines_start = int(d['relevant_lines_start']) # absolute position
|
||||
relevant_lines_end = int(d['relevant_lines_end'])
|
||||
content = d['suggestion_content'].rstrip()
|
||||
new_code_snippet = d['improved_code'].rstrip()
|
||||
label = d['label'].strip()
|
||||
|
||||
if new_code_snippet:
|
||||
new_code_snippet = self.dedent_code(relevant_file, relevant_lines_start, new_code_snippet)
|
||||
|
||||
if d.get('score'):
|
||||
body = f"**Suggestion:** {content} [{label}, importance: {d.get('score')}]\n```suggestion\n" + new_code_snippet + "\n```"
|
||||
else:
|
||||
body = f"**Suggestion:** {content} [{label}]\n```suggestion\n" + new_code_snippet + "\n```"
|
||||
code_suggestions.append({'body': body, 'relevant_file': relevant_file,
|
||||
'relevant_lines_start': relevant_lines_start,
|
||||
'relevant_lines_end': relevant_lines_end,
|
||||
'original_suggestion': d})
|
||||
except Exception:
|
||||
get_logger().info(f"Could not parse suggestion: {d}")
|
||||
|
||||
is_successful = self.git_provider.publish_code_suggestions(code_suggestions)
|
||||
if not is_successful:
|
||||
get_logger().info("Failed to publish code suggestions, trying to publish each suggestion separately")
|
||||
for code_suggestion in code_suggestions:
|
||||
self.git_provider.publish_code_suggestions([code_suggestion])
|
||||
|
||||
def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet):
|
||||
try: # dedent code snippet
|
||||
self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files \
|
||||
else self.git_provider.get_diff_files()
|
||||
original_initial_line = None
|
||||
for file in self.diff_files:
|
||||
if file.filename.strip() == relevant_file:
|
||||
if file.head_file:
|
||||
file_lines = file.head_file.splitlines()
|
||||
if relevant_lines_start > len(file_lines):
|
||||
get_logger().warning(
|
||||
"Could not dedent code snippet, because relevant_lines_start is out of range",
|
||||
artifact={'filename': file.filename,
|
||||
'file_content': file.head_file,
|
||||
'relevant_lines_start': relevant_lines_start,
|
||||
'new_code_snippet': new_code_snippet})
|
||||
return new_code_snippet
|
||||
else:
|
||||
original_initial_line = file_lines[relevant_lines_start - 1]
|
||||
else:
|
||||
get_logger().warning("Could not dedent code snippet, because head_file is missing",
|
||||
artifact={'filename': file.filename,
|
||||
'relevant_lines_start': relevant_lines_start,
|
||||
'new_code_snippet': new_code_snippet})
|
||||
return new_code_snippet
|
||||
break
|
||||
if original_initial_line:
|
||||
suggested_initial_line = new_code_snippet.splitlines()[0]
|
||||
original_initial_spaces = len(original_initial_line) - len(original_initial_line.lstrip())
|
||||
suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip())
|
||||
delta_spaces = original_initial_spaces - suggested_initial_spaces
|
||||
if delta_spaces > 0:
|
||||
new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
|
||||
except Exception as e:
|
||||
get_logger().error(f"Error when dedenting code snippet for file {relevant_file}, error: {e}")
|
||||
|
||||
return new_code_snippet
|
||||
|
||||
def _get_is_extended(self, args: list[str]) -> bool:
|
||||
"""Check if extended mode should be enabled by the `--extended` flag or automatically according to the configuration"""
|
||||
if any(["extended" in arg for arg in args]):
|
||||
get_logger().info("Extended mode is enabled by the `--extended` flag")
|
||||
return True
|
||||
if get_settings().pr_code_suggestions.auto_extended_mode:
|
||||
# get_logger().info("Extended mode is enabled automatically based on the configuration toggle")
|
||||
return True
|
||||
return False
|
||||
|
||||
def remove_line_numbers(self, patches_diff_list: List[str]) -> List[str]:
|
||||
# create a copy of the patches_diff_list, without line numbers for '__new hunk__' sections
|
||||
try:
|
||||
self.patches_diff_list_no_line_numbers = []
|
||||
for patches_diff in self.patches_diff_list:
|
||||
patches_diff_lines = patches_diff.splitlines()
|
||||
for i, line in enumerate(patches_diff_lines):
|
||||
if line.strip():
|
||||
if line.isnumeric():
|
||||
patches_diff_lines[i] = ''
|
||||
elif line[0].isdigit():
|
||||
# skip the leading line number and the separator character right after it
|
||||
for j, char in enumerate(line):
|
||||
if not char.isdigit():
|
||||
patches_diff_lines[i] = line[j + 1:]
|
||||
break
|
||||
self.patches_diff_list_no_line_numbers.append('\n'.join(patches_diff_lines))
|
||||
return self.patches_diff_list_no_line_numbers
|
||||
except Exception as e:
|
||||
get_logger().error(f"Error removing line numbers from patches_diff_list, error: {e}")
|
||||
return patches_diff_list
|
||||
|
||||
async def _prepare_prediction_extended(self, model: str) -> dict:
|
||||
self.patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
|
||||
max_calls=get_settings().pr_code_suggestions.max_number_of_calls)
|
||||
|
||||
# create a copy of the patches_diff_list, without line numbers for '__new hunk__' sections
|
||||
self.patches_diff_list_no_line_numbers = self.remove_line_numbers(self.patches_diff_list)
|
||||
|
||||
if self.patches_diff_list:
|
||||
get_logger().info(f"Number of PR chunk calls: {len(self.patches_diff_list)}")
|
||||
get_logger().debug(f"PR diff:", artifact=self.patches_diff_list)
|
||||
|
||||
# parallelize calls to AI:
|
||||
if get_settings().pr_code_suggestions.parallel_calls:
|
||||
prediction_list = await asyncio.gather(
|
||||
*[self._get_prediction(model, patches_diff, patches_diff_no_line_numbers) for
|
||||
patches_diff, patches_diff_no_line_numbers in
|
||||
zip(self.patches_diff_list, self.patches_diff_list_no_line_numbers)])
|
||||
self.prediction_list = prediction_list
|
||||
else:
|
||||
prediction_list = []
|
||||
for patches_diff, patches_diff_no_line_numbers in zip(self.patches_diff_list, self.patches_diff_list_no_line_numbers):
|
||||
prediction = await self._get_prediction(model, patches_diff, patches_diff_no_line_numbers)
|
||||
prediction_list.append(prediction)
|
||||
|
||||
data = {"code_suggestions": []}
|
||||
for j, predictions in enumerate(prediction_list): # each call adds an element to the list
|
||||
if "code_suggestions" in predictions:
|
||||
score_threshold = max(1, int(get_settings().pr_code_suggestions.suggestions_score_threshold))
|
||||
for i, prediction in enumerate(predictions["code_suggestions"]):
|
||||
try:
|
||||
score = int(prediction.get("score", 1))
|
||||
if score >= score_threshold:
|
||||
data["code_suggestions"].append(prediction)
|
||||
else:
|
||||
get_logger().info(
|
||||
f"Removing suggestions {i} from call {j}, because score is {score}, and score_threshold is {score_threshold}",
|
||||
artifact=prediction)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Error getting PR diff for suggestion {i} in call {j}, error: {e}",
|
||||
artifact={"prediction": prediction})
|
||||
self.data = data
|
||||
else:
|
||||
get_logger().warning(f"Empty PR diff list")
|
||||
self.data = data = None
|
||||
return data
|
||||
|
||||
def generate_summarized_suggestions(self, data: Dict) -> str:
|
||||
try:
|
||||
pr_body = "## PR 代码建议 ✨\n\n"
|
||||
|
||||
if len(data.get('code_suggestions', [])) == 0:
|
||||
pr_body += "No suggestions found to improve this PR."
|
||||
return pr_body
|
||||
|
||||
if get_settings().pr_code_suggestions.enable_intro_text and get_settings().config.is_auto_command:
|
||||
pr_body += "Explore these optional code suggestions:\n\n"
|
||||
|
||||
language_extension_map_org = get_settings().language_extension_map_org
|
||||
extension_to_language = {}
|
||||
for language, extensions in language_extension_map_org.items():
|
||||
for ext in extensions:
|
||||
extension_to_language[ext] = language
|
||||
|
||||
pr_body += "<table>"
|
||||
header = f"建议"
|
||||
delta = 66
|
||||
header += " " * delta
|
||||
pr_body += f"""<thead><tr><td><strong>类别</strong></td><td align=left><strong>{header}</strong></td><td align=center><strong>影响</strong></td></tr>"""
|
||||
pr_body += """<tbody>"""
|
||||
suggestions_labels = dict()
|
||||
# add all suggestions related to each label
|
||||
for suggestion in data['code_suggestions']:
|
||||
label = suggestion['label'].strip().strip("'").strip('"')
|
||||
if label not in suggestions_labels:
|
||||
suggestions_labels[label] = []
|
||||
suggestions_labels[label].append(suggestion)
|
||||
|
||||
# sort suggestions_labels by the suggestion with the highest score
|
||||
suggestions_labels = dict(
|
||||
sorted(suggestions_labels.items(), key=lambda x: max([s['score'] for s in x[1]]), reverse=True))
|
||||
# sort the suggestions inside each label group by score
|
||||
for label, suggestions in suggestions_labels.items():
|
||||
suggestions_labels[label] = sorted(suggestions, key=lambda x: x['score'], reverse=True)
|
||||
|
||||
counter_suggestions = 0
|
||||
for label, suggestions in suggestions_labels.items():
|
||||
num_suggestions = len(suggestions)
|
||||
pr_body += f"""<tr><td rowspan={num_suggestions}>{label.capitalize()}</td>\n"""
|
||||
for i, suggestion in enumerate(suggestions):
|
||||
|
||||
relevant_file = suggestion['relevant_file'].strip()
|
||||
relevant_lines_start = int(suggestion['relevant_lines_start'])
|
||||
relevant_lines_end = int(suggestion['relevant_lines_end'])
|
||||
range_str = ""
|
||||
if relevant_lines_start == relevant_lines_end:
|
||||
range_str = f"[{relevant_lines_start}]"
|
||||
else:
|
||||
range_str = f"[{relevant_lines_start}-{relevant_lines_end}]"
|
||||
|
||||
try:
|
||||
code_snippet_link = self.git_provider.get_line_link(relevant_file, relevant_lines_start,
|
||||
relevant_lines_end)
|
||||
except:
|
||||
code_snippet_link = ""
|
||||
# add html table for each suggestion
|
||||
|
||||
suggestion_content = suggestion['suggestion_content'].rstrip()
|
||||
CHAR_LIMIT_PER_LINE = 84
|
||||
suggestion_content = insert_br_after_x_chars(suggestion_content, CHAR_LIMIT_PER_LINE)
|
||||
# pr_body += f"<tr><td><details><summary>{suggestion_content}</summary>"
|
||||
existing_code = suggestion['existing_code'].rstrip() + "\n"
|
||||
improved_code = suggestion['improved_code'].rstrip() + "\n"
|
||||
|
||||
diff = difflib.unified_diff(existing_code.split('\n'),
|
||||
improved_code.split('\n'), n=999)
|
||||
patch_orig = "\n".join(diff)
|
||||
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
|
||||
|
||||
example_code = ""
|
||||
example_code += f"```diff\n{patch.rstrip()}\n```\n"
|
||||
if i == 0:
|
||||
pr_body += f"""<td>\n\n"""
|
||||
else:
|
||||
pr_body += f"""<tr><td>\n\n"""
|
||||
suggestion_summary = suggestion['one_sentence_summary'].strip().rstrip('.')
|
||||
if "'<" in suggestion_summary and ">'" in suggestion_summary:
|
||||
# escape the '<' and '>' characters, otherwise they are interpreted as html tags
|
||||
get_logger().info(f"Escaped suggestion summary: {suggestion_summary}")
|
||||
suggestion_summary = suggestion_summary.replace("'<", "`<")
|
||||
suggestion_summary = suggestion_summary.replace(">'", ">`")
|
||||
if '`' in suggestion_summary:
|
||||
suggestion_summary = replace_code_tags(suggestion_summary)
|
||||
|
||||
pr_body += f"""\n\n<details><summary>{suggestion_summary}</summary>\n\n___\n\n"""
|
||||
pr_body += f"""
|
||||
**{suggestion_content}**
|
||||
|
||||
[{relevant_file} {range_str}]({code_snippet_link})
|
||||
|
||||
{example_code.rstrip()}
|
||||
"""
|
||||
if suggestion.get('score_why'):
|
||||
pr_body += f"<details><summary>严重性 [1-10]: {suggestion['score']}</summary>\n\n"
|
||||
pr_body += f"__\n\nWhy: {suggestion['score_why']}\n\n"
|
||||
pr_body += f"</details>"
|
||||
|
||||
pr_body += f"</details>"
|
||||
|
||||
# # add another column for 'score'
|
||||
score_int = int(suggestion.get('score', 0))
|
||||
score_str = f"{score_int}"
|
||||
if get_settings().pr_code_suggestions.new_score_mechanism:
|
||||
score_str = self.get_score_str(score_int)
|
||||
pr_body += f"</td><td align=center>{score_str}\n\n"
|
||||
|
||||
pr_body += f"</td></tr>"
|
||||
counter_suggestions += 1
|
||||
|
||||
# pr_body += "</details>"
|
||||
# pr_body += """</td></tr>"""
|
||||
pr_body += """</tr></tbody></table>"""
|
||||
return pr_body
|
||||
except Exception as e:
|
||||
get_logger().info(f"Failed to publish summarized code suggestions, error: {e}")
|
||||
return ""
|
||||
|
||||
def get_score_str(self, score: int) -> str:
|
||||
th_high = get_settings().pr_code_suggestions.get('new_score_mechanism_th_high', 9)
|
||||
th_medium = get_settings().pr_code_suggestions.get('new_score_mechanism_th_medium', 7)
|
||||
if score >= th_high:
|
||||
return "高"
|
||||
elif score >= th_medium:
|
||||
return "中"
|
||||
else:  # score < th_medium
|
||||
return "低"
|
||||
|
||||
async def self_reflect_on_suggestions(self,
|
||||
suggestion_list: List,
|
||||
patches_diff: str,
|
||||
model: str,
|
||||
prev_suggestions_str: str = "",
|
||||
dedicated_prompt: str = "") -> str:
|
||||
if not suggestion_list:
|
||||
return ""
|
||||
|
||||
try:
|
||||
suggestion_str = ""
|
||||
for i, suggestion in enumerate(suggestion_list):
|
||||
suggestion_str += f"suggestion {i + 1}: " + str(suggestion) + '\n\n'
|
||||
|
||||
variables = {'suggestion_list': suggestion_list,
|
||||
'suggestion_str': suggestion_str,
|
||||
"diff": patches_diff,
|
||||
'num_code_suggestions': len(suggestion_list),
|
||||
'prev_suggestions_str': prev_suggestions_str,
|
||||
"is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
|
||||
'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False)}
|
||||
environment = Environment(undefined=StrictUndefined)
|
||||
|
||||
if dedicated_prompt:
|
||||
system_prompt_reflect = environment.from_string(
|
||||
get_settings().get(dedicated_prompt).system).render(variables)
|
||||
user_prompt_reflect = environment.from_string(
|
||||
get_settings().get(dedicated_prompt).user).render(variables)
|
||||
else:
|
||||
system_prompt_reflect = environment.from_string(
|
||||
get_settings().pr_code_suggestions_reflect_prompt.system).render(variables)
|
||||
user_prompt_reflect = environment.from_string(
|
||||
get_settings().pr_code_suggestions_reflect_prompt.user).render(variables)
|
||||
|
||||
with get_logger().contextualize(command="self_reflect_on_suggestions"):
|
||||
response_reflect, finish_reason_reflect = await self.ai_handler.chat_completion(model=model,
|
||||
system=system_prompt_reflect,
|
||||
user=user_prompt_reflect)
|
||||
except Exception as e:
|
||||
get_logger().info(f"Could not reflect on suggestions, error: {e}")
|
||||
return ""
|
||||
return response_reflect
|
||||
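Illustration only (not part of this commit): PRCodeSuggestions can be invoked in the same way as PRAddDocs; whether run() publishes a summarized table comment or inline committable comments is governed by the commitable_code_suggestions toggle checked in run() above. The PR URL is a placeholder.

```python
import asyncio

from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.tools.pr_code_suggestions import PRCodeSuggestions


async def main() -> None:
    # Publish a summarized table instead of inline committable suggestions.
    get_settings().set("pr_code_suggestions.commitable_code_suggestions", False)
    await PRCodeSuggestions(pr_url="https://example.com/org/repo/pull/1").run()


if __name__ == "__main__":
    asyncio.run(main())
```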