Code cleanup: improve clarity and maintainability.

张建平 2025-02-27 11:07:34 +08:00
parent 1988a400c9
commit de84796560
67 changed files with 5417 additions and 2190 deletions

1
.gitignore vendored
View File

@@ -14,3 +14,4 @@ docs/.cache/
 db.sqlite3
 #pr_agent/
 static/admin/
+config.local.ini

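Ignoring config.local.ini alongside the tracked configuration is the usual local-override pattern: developer-specific settings layered over committed defaults and never pushed. The loader below is only a sketch of that pattern, not code from this commit; the file names and the [database] section are assumptions.

    import configparser

    def load_config():
        # Later files in the list override earlier ones; config.local.ini is
        # optional and, per the .gitignore entry above, never committed.
        parser = configparser.ConfigParser()
        parser.read(["config.ini", "config.local.ini"])
        return parser

    config = load_config()
    db_host = config.get("database", "host", fallback="localhost")
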
Pipfile
View File

@@ -20,9 +20,9 @@ pygithub = "*"
 python-gitlab = "*"
 retry = "*"
 fastapi = "*"
+psycopg2-binary = "*"

 [dev-packages]

 [requires]
 python_version = "3.12"

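psycopg2-binary is the pre-built PostgreSQL driver, so this dependency addition points the Django project at PostgreSQL. For reference, a minimal DATABASES block that would exercise it is sketched below; every connection value is a placeholder, not something taken from this commit.

    # settings.py (sketch): Django loads psycopg2/psycopg2-binary via this backend
    DATABASES = {
        "default": {
            "ENGINE": "django.db.backends.postgresql",
            "NAME": "pr_agent_db",    # placeholder
            "USER": "pr_agent",       # placeholder
            "PASSWORD": "change-me",  # placeholder
            "HOST": "127.0.0.1",
            "PORT": "5432",
        }
    }
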
106
Pipfile.lock generated
View File

@ -1,7 +1,7 @@
{ {
"_meta": { "_meta": {
"hash": { "hash": {
"sha256": "420206f7faa4351eabc368a83deae9b7ed9e50b0975ac63a46d6367e9920848b" "sha256": "497c1ff8497659883faf8dcca407665df1b3a37f67720f64b139f9dec8202892"
}, },
"pipfile-spec": 6, "pipfile-spec": 6,
"requires": { "requires": {
@ -169,20 +169,19 @@
}, },
"boto3": { "boto3": {
"hashes": [ "hashes": [
"sha256:01015b38017876d79efd7273f35d9a4adfba505237159621365bed21b9b65eca", "sha256:e58136d52d79425ce26c3c1578bf94d4b2e91ead55fed9f6950406ee9713e6af"
"sha256:03bd8c93b226f07d944fd6b022e11a307bff94ab6a21d51675d7e3ea81ee8424"
], ],
"index": "pip_conf_index_global", "index": "pip_conf_index_global",
"markers": "python_version >= '3.8'", "markers": "python_version >= '3.8'",
"version": "==1.37.0" "version": "==1.37.2"
}, },
"botocore": { "botocore": {
"hashes": [ "hashes": [
"sha256:b129d091a8360b4152ab65327186bf4e250de827c4a9b7ddf40a72b1acf1f3c1", "sha256:3f460f3c32cd6d747d5897a9cbde011bf1715abc7bf0a6ea6fdb0b812df63287",
"sha256:d01661f38c0edac87424344cdf4169f3ab9bc1bf1b677c8b230d025eb66c54a3" "sha256:5f59b966f3cd0c8055ef6f7c2600f7db5f8218071d992e5f95da3f9156d4370f"
], ],
"markers": "python_version >= '3.8'", "markers": "python_version >= '3.8'",
"version": "==1.37.0" "version": "==1.37.2"
}, },
"certifi": { "certifi": {
"hashes": [ "hashes": [
@ -460,12 +459,12 @@
}, },
"django-import-export": { "django-import-export": {
"hashes": [ "hashes": [
"sha256:317842a64233025a277040129fb6792fc48fd39622c185b70bf8c18c393d708f", "sha256:5514d09636e84e823a42cd5e79292f70f20d6d2feed117a145f5b64a5b44f168",
"sha256:ecb4e6cdb4790d69bce261f9cca1007ca19cb431bb5a950ba907898245c8817b" "sha256:bd3fe0aa15a2bce9de4be1a2f882e2c4539fdbfdfa16f2052c98dd7aec0f085c"
], ],
"index": "pip_conf_index_global", "index": "pip_conf_index_global",
"markers": "python_version >= '3.9'", "markers": "python_version >= '3.9'",
"version": "==4.3.6" "version": "==4.3.7"
}, },
"django-simpleui": { "django-simpleui": {
"hashes": [ "hashes": [
@ -794,12 +793,12 @@
}, },
"litellm": { "litellm": {
"hashes": [ "hashes": [
"sha256:02df5865f98ea9734a4d27ac7c33aad9a45c4015403d5c0797d3292ade3c5cb5", "sha256:eaab989c090ccc094b41c3fdf27d1df7f6fb25e091ab0ce48e0f3079f1e51ff5",
"sha256:d241436ac0edf64ec57fb5686f8d84a25998a7e52213d9063adf87df8432701f" "sha256:ff9137c008cdb421db32defb1fbd1ed546a95167de6d276c61b664582ed4ff60"
], ],
"index": "pip_conf_index_global", "index": "pip_conf_index_global",
"markers": "python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7' and python_version >= '3.8'", "markers": "python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7' and python_version >= '3.8'",
"version": "==1.61.16" "version": "==1.61.17"
}, },
"loguru": { "loguru": {
"hashes": [ "hashes": [
@ -1196,6 +1195,81 @@
"markers": "python_version >= '3.6'", "markers": "python_version >= '3.6'",
"version": "==7.0.0" "version": "==7.0.0"
}, },
"psycopg2-binary": {
"hashes": [
"sha256:04392983d0bb89a8717772a193cfaac58871321e3ec69514e1c4e0d4957b5aff",
"sha256:056470c3dc57904bbf63d6f534988bafc4e970ffd50f6271fc4ee7daad9498a5",
"sha256:0ea8e3d0ae83564f2fc554955d327fa081d065c8ca5cc6d2abb643e2c9c1200f",
"sha256:155e69561d54d02b3c3209545fb08938e27889ff5a10c19de8d23eb5a41be8a5",
"sha256:18c5ee682b9c6dd3696dad6e54cc7ff3a1a9020df6a5c0f861ef8bfd338c3ca0",
"sha256:19721ac03892001ee8fdd11507e6a2e01f4e37014def96379411ca99d78aeb2c",
"sha256:1a6784f0ce3fec4edc64e985865c17778514325074adf5ad8f80636cd029ef7c",
"sha256:2286791ececda3a723d1910441c793be44625d86d1a4e79942751197f4d30341",
"sha256:230eeae2d71594103cd5b93fd29d1ace6420d0b86f4778739cb1a5a32f607d1f",
"sha256:245159e7ab20a71d989da00f280ca57da7641fa2cdcf71749c193cea540a74f7",
"sha256:26540d4a9a4e2b096f1ff9cce51253d0504dca5a85872c7f7be23be5a53eb18d",
"sha256:270934a475a0e4b6925b5f804e3809dd5f90f8613621d062848dd82f9cd62007",
"sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142",
"sha256:2ad26b467a405c798aaa1458ba09d7e2b6e5f96b1ce0ac15d82fd9f95dc38a92",
"sha256:2b3d2491d4d78b6b14f76881905c7a8a8abcf974aad4a8a0b065273a0ed7a2cb",
"sha256:2ce3e21dc3437b1d960521eca599d57408a695a0d3c26797ea0f72e834c7ffe5",
"sha256:30e34c4e97964805f715206c7b789d54a78b70f3ff19fbe590104b71c45600e5",
"sha256:3216ccf953b3f267691c90c6fe742e45d890d8272326b4a8b20850a03d05b7b8",
"sha256:32581b3020c72d7a421009ee1c6bf4a131ef5f0a968fab2e2de0c9d2bb4577f1",
"sha256:35958ec9e46432d9076286dda67942ed6d968b9c3a6a2fd62b48939d1d78bf68",
"sha256:3abb691ff9e57d4a93355f60d4f4c1dd2d68326c968e7db17ea96df3c023ef73",
"sha256:3c18f74eb4386bf35e92ab2354a12c17e5eb4d9798e4c0ad3a00783eae7cd9f1",
"sha256:3c4745a90b78e51d9ba06e2088a2fe0c693ae19cc8cb051ccda44e8df8a6eb53",
"sha256:3c4ded1a24b20021ebe677b7b08ad10bf09aac197d6943bfe6fec70ac4e4690d",
"sha256:3e9c76f0ac6f92ecfc79516a8034a544926430f7b080ec5a0537bca389ee0906",
"sha256:48b338f08d93e7be4ab2b5f1dbe69dc5e9ef07170fe1f86514422076d9c010d0",
"sha256:4b3df0e6990aa98acda57d983942eff13d824135fe2250e6522edaa782a06de2",
"sha256:512d29bb12608891e349af6a0cccedce51677725a921c07dba6342beaf576f9a",
"sha256:5a507320c58903967ef7384355a4da7ff3f28132d679aeb23572753cbf2ec10b",
"sha256:5c370b1e4975df846b0277b4deba86419ca77dbc25047f535b0bb03d1a544d44",
"sha256:6b269105e59ac96aba877c1707c600ae55711d9dcd3fc4b5012e4af68e30c648",
"sha256:6d4fa1079cab9018f4d0bd2db307beaa612b0d13ba73b5c6304b9fe2fb441ff7",
"sha256:6dc08420625b5a20b53551c50deae6e231e6371194fa0651dbe0fb206452ae1f",
"sha256:73aa0e31fa4bb82578f3a6c74a73c273367727de397a7a0f07bd83cbea696baa",
"sha256:7559bce4b505762d737172556a4e6ea8a9998ecac1e39b5233465093e8cee697",
"sha256:79625966e176dc97ddabc142351e0409e28acf4660b88d1cf6adb876d20c490d",
"sha256:7a813c8bdbaaaab1f078014b9b0b13f5de757e2b5d9be6403639b298a04d218b",
"sha256:7b2c956c028ea5de47ff3a8d6b3cc3330ab45cf0b7c3da35a2d6ff8420896526",
"sha256:7f4152f8f76d2023aac16285576a9ecd2b11a9895373a1f10fd9db54b3ff06b4",
"sha256:7f5d859928e635fa3ce3477704acee0f667b3a3d3e4bb109f2b18d4005f38287",
"sha256:851485a42dbb0bdc1edcdabdb8557c09c9655dfa2ca0460ff210522e073e319e",
"sha256:8608c078134f0b3cbd9f89b34bd60a943b23fd33cc5f065e8d5f840061bd0673",
"sha256:880845dfe1f85d9d5f7c412efea7a08946a46894537e4e5d091732eb1d34d9a0",
"sha256:8aabf1c1a04584c168984ac678a668094d831f152859d06e055288fa515e4d30",
"sha256:8aecc5e80c63f7459a1a2ab2c64df952051df196294d9f739933a9f6687e86b3",
"sha256:8cd9b4f2cfab88ed4a9106192de509464b75a906462fb846b936eabe45c2063e",
"sha256:8de718c0e1c4b982a54b41779667242bc630b2197948405b7bd8ce16bcecac92",
"sha256:9440fa522a79356aaa482aa4ba500b65f28e5d0e63b801abf6aa152a29bd842a",
"sha256:b5f86c56eeb91dc3135b3fd8a95dc7ae14c538a2f3ad77a19645cf55bab1799c",
"sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8",
"sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909",
"sha256:c3cc28a6fd5a4a26224007712e79b81dbaee2ffb90ff406256158ec4d7b52b47",
"sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864",
"sha256:d00924255d7fc916ef66e4bf22f354a940c67179ad3fd7067d7a0a9c84d2fbfc",
"sha256:d7cd730dfa7c36dbe8724426bf5612798734bff2d3c3857f36f2733f5bfc7c00",
"sha256:e217ce4d37667df0bc1c397fdcd8de5e81018ef305aed9415c3b093faaeb10fb",
"sha256:e3923c1d9870c49a2d44f795df0c889a22380d36ef92440ff618ec315757e539",
"sha256:e5720a5d25e3b99cd0dc5c8a440570469ff82659bb09431c1439b92caf184d3b",
"sha256:e8b58f0a96e7a1e341fc894f62c1177a7c83febebb5ff9123b579418fdc8a481",
"sha256:e984839e75e0b60cfe75e351db53d6db750b00de45644c5d1f7ee5d1f34a1ce5",
"sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4",
"sha256:ec8a77f521a17506a24a5f626cb2aee7850f9b69a0afe704586f63a464f3cd64",
"sha256:ecced182e935529727401b24d76634a357c71c9275b356efafd8a2a91ec07392",
"sha256:ee0e8c683a7ff25d23b55b11161c2663d4b099770f6085ff0a20d4505778d6b4",
"sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1",
"sha256:f758ed67cab30b9a8d2833609513ce4d3bd027641673d4ebc9c067e4d208eec1",
"sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567",
"sha256:ffe8ed017e4ed70f68b7b371d84b7d4a790368db9203dfc2d222febd3a9c8863"
],
"index": "pip_conf_index_global",
"markers": "python_version >= '3.8'",
"version": "==2.9.10"
},
"py": { "py": {
"hashes": [ "hashes": [
"sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719", "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719",
@ -1766,11 +1840,11 @@
}, },
"s3transfer": { "s3transfer": {
"hashes": [ "hashes": [
"sha256:3b39185cb72f5acc77db1a58b6e25b977f28d20496b6e58d6813d75f464d632f", "sha256:ca855bdeb885174b5ffa95b9913622459d4ad8e331fc98eb01e6d5eb6a30655d",
"sha256:be6ecb39fadd986ef1701097771f87e4d2f821f27f6071c872143884d2950fbc" "sha256:edae4977e3a122445660c7c114bba949f9d191bae3b34a096f18a1c8c354527a"
], ],
"markers": "python_version >= '3.8'", "markers": "python_version >= '3.8'",
"version": "==0.11.2" "version": "==0.11.3"
}, },
"simplepro": { "simplepro": {
"hashes": [ "hashes": [

View File

@@ -34,7 +34,13 @@ class GitConfigAdmin(AjaxAdmin):
 class ProjectConfigAdmin(AjaxAdmin):
     """Admin配置"""

-    list_display = ["project_id", "project_name", "project_secret", "commands", "is_enable"]
+    list_display = [
+        "project_id",
+        "project_name",
+        "project_secret",
+        "commands",
+        "is_enable",
+    ]
     readonly_fields = ["create_by", "delete_at", "detail"]
     top_html = '<el-alert title="可配置多个项目!" type="success"></el-alert>'

View File

@@ -16,4 +16,3 @@ class Command(BaseCommand):
             print("初始化AI配置已创建")
         else:
             print("初始化AI配置已存在")
-

View File

@@ -44,9 +44,7 @@ class GitConfig(BaseModel):
         null=True, blank=True, max_length=16, verbose_name="Git名称"
     )
     git_type = fields.RadioField(
-        choices=constant.GIT_TYPE,
-        default=0,
-        verbose_name="Git类型"
+        choices=constant.GIT_TYPE, default=0, verbose_name="Git类型"
     )
     git_url = fields.CharField(
         null=True, blank=True, max_length=128, verbose_name="Git地址"
@@ -67,6 +65,7 @@ class ProjectConfig(BaseModel):
     """
     项目配置表
     """
+
     git_config = fields.ForeignKey(
         GitConfig,
         null=True,
@@ -89,10 +88,7 @@
         max_length=256,
         verbose_name="默认命令",
     )
-    is_enable = fields.SwitchField(
-        default=True,
-        verbose_name="是否启用"
-    )
+    is_enable = fields.SwitchField(default=True, verbose_name="是否启用")

     class Meta:
         verbose_name = "项目配置"
@@ -106,6 +102,7 @@ class ProjectHistory(BaseModel):
     """
     项目历史表
     """
+
     project = fields.ForeignKey(
         ProjectConfig,
         null=True,
@@ -128,9 +125,7 @@
     mr_title = fields.CharField(
         null=True, blank=True, max_length=256, verbose_name="MR标题"
     )
-    source_data = models.JSONField(
-        null=True, blank=True, verbose_name="源数据"
-    )
+    source_data = models.JSONField(null=True, blank=True, verbose_name="源数据")

     class Meta:
         verbose_name = "项目历史"

View File

@@ -12,12 +12,7 @@ from utils import constant
 def load_project_config(
-    git_url,
-    access_token,
-    project_secret,
-    openai_api_base,
-    openai_key,
-    llm_model
+    git_url, access_token, project_secret, openai_api_base, openai_key, llm_model
 ):
     """
     加载项目配置
@@ -36,12 +31,11 @@
         "secret": project_secret,
         "openai_api_base": openai_api_base,
         "openai_key": openai_key,
-        "llm_model": llm_model
+        "llm_model": llm_model,
     }


 class WebHookView(View):
     @staticmethod
     def select_git_provider(git_type):
         """
@@ -82,7 +76,9 @@
         project_config = provider.get_project_config(project_id=project_id)

         # Token 校验
-        provider.check_secret(request_headers=headers, project_secret=project_config.get("project_secret"))
+        provider.check_secret(
+            request_headers=headers, project_secret=project_config.get("project_secret")
+        )

         provider.get_merge_request(
             request_data=json_data,
@@ -91,11 +87,13 @@
             api_base=project_config.get("api_base"),
             api_key=project_config.get("api_key"),
             llm_model=project_config.get("llm_model"),
-            project_commands=project_config.get("commands")
+            project_commands=project_config.get("commands"),
         )
         # 记录请求日志: 目前仅记录合并日志
         if json_data.get('object_kind') == 'merge_request':
-            provider.save_pr_agent_log(request_data=json_data, project_id=project_config.get("project_id"))
+            provider.save_pr_agent_log(
+                request_data=json_data, project_id=project_config.get("project_id")
+            )

         return JsonResponse(status=200, data={"status": "ignored"})

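check_secret is only called in this hunk; its body is not part of the diff. For GitLab-style webhooks the conventional check compares the X-Gitlab-Token header against the stored secret in constant time, roughly as sketched here (the header name and error type are assumptions, not the project's actual implementation):

    import hmac

    def check_secret(request_headers, project_secret):
        # GitLab echoes the configured webhook secret verbatim in this header.
        token = request_headers.get("X-Gitlab-Token", "")
        # compare_digest avoids leaking the secret through timing differences.
        if not hmac.compare_digest(token, project_secret or ""):
            raise PermissionError("webhook secret mismatch")
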
View File

@ -1,8 +1,4 @@
GIT_TYPE = ( GIT_TYPE = ((0, "gitlab"), (1, "github"), (2, "gitea"))
(0, "gitlab"),
(1, "github"),
(2, "gitea")
)
DEFAULT_COMMANDS = ( DEFAULT_COMMANDS = (
("/review", "/review"), ("/review", "/review"),
@ -10,11 +6,7 @@ DEFAULT_COMMANDS = (
("/improve_code", "/improve_code"), ("/improve_code", "/improve_code"),
) )
UA_TYPE = { UA_TYPE = {"GitLab": "gitlab", "GitHub": "github", "Go-http-client": "gitea"}
"GitLab": "gitlab",
"GitHub": "github",
"Go-http-client": "gitea"
}
def get_git_type_from_ua(ua_value): def get_git_type_from_ua(ua_value):

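The body of get_git_type_from_ua is cut off by the hunk boundary. Given the UA_TYPE mapping above and User-Agent values such as "GitLab/16.9" or "Go-http-client/1.1", a plausible reconstruction is a prefix match over the keys; this is a sketch, not the committed implementation:

    def get_git_type_from_ua(ua_value):
        # "GitLab/...", "GitHub-Hookshot/..." and "Go-http-client/1.1" all begin
        # with a key of UA_TYPE, so a startswith check is sufficient.
        for prefix, git_type in UA_TYPE.items():
            if ua_value and ua_value.startswith(prefix):
                return git_type
        return None
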
View File

@@ -23,7 +23,7 @@ class GitProvider(ABC):
         api_base,
         api_key,
         llm_model,
-        project_commands
+        project_commands,
     ):
         pass
@@ -33,7 +33,6 @@

-
 class GitLabProvider(GitProvider):
     @staticmethod
     def check_secret(request_headers, project_secret):
         """
@@ -79,7 +78,7 @@
             "access_token": git_config.access_token,
             "project_secret": project_config.project_secret,
             "commands": project_config.commands.split(","),
-            "project_id": project_config.id
+            "project_id": project_config.id,
         }

     def get_merge_request(
@@ -124,7 +123,10 @@
                 self.run_command(mr_url, project_commands)
                 # 数据库留存
                 return JsonResponse(status=200, data={"status": "review started"})
-        return JsonResponse(status=400, data={"error": "Merge request URL not found or action not open"})
+        return JsonResponse(
+            status=400,
+            data={"error": "Merge request URL not found or action not open"},
+        )

     @staticmethod
     def save_pr_agent_log(request_data, project_id):
@@ -134,13 +136,19 @@
         :param project_id:
         :return:
         """
-        if request_data.get('object_attributes', {}).get("source_branch") and request_data.get('object_attributes', {}).get("target_branch"):
+        if request_data.get('object_attributes', {}).get(
+            "source_branch"
+        ) and request_data.get('object_attributes', {}).get("target_branch"):
             models.ProjectHistory.objects.create(
                 project_id=project_id,
                 project_url=request_data.get("project", {}).get("web_url"),
                 mr_url=request_data.get('object_attributes', {}).get("url"),
-                source_branch=request_data.get('object_attributes', {}).get("source_branch"),
-                target_branch=request_data.get('object_attributes', {}).get("target_branch"),
+                source_branch=request_data.get('object_attributes', {}).get(
+                    "source_branch"
+                ),
+                target_branch=request_data.get('object_attributes', {}).get(
+                    "target_branch"
+                ),
                 mr_title=request_data.get('object_attributes', {}).get("title"),
                 source_data=request_data,
             )

View File

@@ -80,14 +80,20 @@ class PRAgent:
         if action == "answer":
             if notify:
                 notify()
-            await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run()
+            await PRReviewer(
+                pr_url, is_answer=True, args=args, ai_handler=self.ai_handler
+            ).run()
         elif action == "auto_review":
-            await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run()
+            await PRReviewer(
+                pr_url, is_auto=True, args=args, ai_handler=self.ai_handler
+            ).run()
         elif action in command2class:
             if notify:
                 notify()
-            await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run()
+            await command2class[action](
+                pr_url, ai_handler=self.ai_handler, args=args
+            ).run()
         else:
             return False
         return True

View File

@@ -88,7 +88,7 @@ USER_MESSAGE_ONLY_MODELS = [
     "deepseek/deepseek-reasoner",
     "o1-mini",
     "o1-mini-2024-09-12",
-    "o1-preview"
+    "o1-preview",
 ]

 NO_SUPPORT_TEMPERATURE_MODELS = [
@@ -99,5 +99,5 @@ NO_SUPPORT_TEMPERATURE_MODELS = [
     "o1-2024-12-17",
     "o3-mini",
     "o3-mini-2025-01-31",
-    "o1-preview"
+    "o1-preview",
 ]

View File

@@ -16,7 +16,14 @@ class BaseAiHandler(ABC):
         pass

     @abstractmethod
-    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
+    async def chat_completion(
+        self,
+        model: str,
+        system: str,
+        user: str,
+        temperature: float = 0.2,
+        img_path: str = None,
+    ):
         """
         This method should be implemented to return a chat completion from the AI model.
         Args:

View File

@@ -34,9 +34,16 @@ class LangChainOpenAIHandler(BaseAiHandler):
         """
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

-    @retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError),
-           tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
-    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
+    @retry(
+        exceptions=(APIError, Timeout, AttributeError, RateLimitError),
+        tries=OPENAI_RETRIES,
+        delay=2,
+        backoff=2,
+        jitter=(1, 3),
+    )
+    async def chat_completion(
+        self, model: str, system: str, user: str, temperature: float = 0.2
+    ):
         try:
             messages = [SystemMessage(content=system), HumanMessage(content=user)]
@@ -45,7 +52,7 @@
             finish_reason = "completed"
             return resp.content, finish_reason

-        except (Exception) as e:
+        except Exception as e:
             get_logger().error("Unknown error during OpenAI inference: ", e)
             raise e
@@ -66,7 +73,10 @@
             if openai_api_base is None or len(openai_api_base) == 0:
                 return ChatOpenAI(openai_api_key=get_settings().openai.key)
             else:
-                return ChatOpenAI(openai_api_key=get_settings().openai.key, openai_api_base=openai_api_base)
+                return ChatOpenAI(
+                    openai_api_key=get_settings().openai.key,
+                    openai_api_base=openai_api_base,
+                )
         except AttributeError as e:
             if getattr(e, "name"):
                 raise ValueError(f"OpenAI {e.name} is required") from e

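The @retry decorator reformatted above comes from the retry package pinned in the Pipfile: it sleeps delay seconds after the first failure, multiplies the wait by backoff after each subsequent one, and adds a random jitter drawn from the (min, max) tuple. A standalone sketch with the same arguments, using a stand-in exception:

    from retry import retry

    attempts = 0

    @retry(exceptions=(TimeoutError,), tries=4, delay=2, backoff=2, jitter=(1, 3))
    def flaky_call():
        # Fails three times, then succeeds; the waits grow from about 2 s to
        # roughly 5-7 s and then 11-17 s before tries would be exhausted.
        global attempts
        attempts += 1
        if attempts < 4:
            raise TimeoutError("transient failure")
        return "ok"

    print(flaky_call())  # "ok" on the 4th attempt
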
View File

@@ -36,9 +36,14 @@ class LiteLLMAIHandler(BaseAiHandler):
         elif 'OPENAI_API_KEY' not in os.environ:
             litellm.api_key = "dummy_key"
         if get_settings().get("aws.AWS_ACCESS_KEY_ID"):
-            assert get_settings().aws.AWS_SECRET_ACCESS_KEY and get_settings().aws.AWS_REGION_NAME, "AWS credentials are incomplete"
+            assert (
+                get_settings().aws.AWS_SECRET_ACCESS_KEY
+                and get_settings().aws.AWS_REGION_NAME
+            ), "AWS credentials are incomplete"
             os.environ["AWS_ACCESS_KEY_ID"] = get_settings().aws.AWS_ACCESS_KEY_ID
-            os.environ["AWS_SECRET_ACCESS_KEY"] = get_settings().aws.AWS_SECRET_ACCESS_KEY
+            os.environ[
+                "AWS_SECRET_ACCESS_KEY"
+            ] = get_settings().aws.AWS_SECRET_ACCESS_KEY
             os.environ["AWS_REGION_NAME"] = get_settings().aws.AWS_REGION_NAME
         if get_settings().get("litellm.use_client"):
             litellm_token = get_settings().get("litellm.LITELLM_TOKEN")
@@ -73,14 +78,19 @@
             litellm.replicate_key = get_settings().replicate.key
         if get_settings().get("HUGGINGFACE.KEY", None):
             litellm.huggingface_key = get_settings().huggingface.key
-        if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
+        if (
+            get_settings().get("HUGGINGFACE.API_BASE", None)
+            and 'huggingface' in get_settings().config.model
+        ):
             litellm.api_base = get_settings().huggingface.api_base
             self.api_base = get_settings().huggingface.api_base
         if get_settings().get("OLLAMA.API_BASE", None):
             litellm.api_base = get_settings().ollama.api_base
             self.api_base = get_settings().ollama.api_base
         if get_settings().get("HUGGINGFACE.REPETITION_PENALTY", None):
-            self.repetition_penalty = float(get_settings().huggingface.repetition_penalty)
+            self.repetition_penalty = float(
+                get_settings().huggingface.repetition_penalty
+            )
         if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
             litellm.vertex_project = get_settings().vertexai.vertex_project
             litellm.vertex_location = get_settings().get(
@@ -89,7 +99,9 @@
         # Google AI Studio
         # SEE https://docs.litellm.ai/docs/providers/gemini
         if get_settings().get("GOOGLE_AI_STUDIO.GEMINI_API_KEY", None):
-            os.environ["GEMINI_API_KEY"] = get_settings().google_ai_studio.gemini_api_key
+            os.environ[
+                "GEMINI_API_KEY"
+            ] = get_settings().google_ai_studio.gemini_api_key

         # Support deepseek models
         if get_settings().get("DEEPSEEK.KEY", None):
@@ -140,18 +152,25 @@
         git_provider = get_settings().config.git_provider
         metadata = dict()
-        callbacks = litellm.success_callback + litellm.failure_callback + litellm.service_callback
+        callbacks = (
+            litellm.success_callback
+            + litellm.failure_callback
+            + litellm.service_callback
+        )
         if "langfuse" in callbacks:
-            metadata.update({
-                "trace_name": command,
-                "tags": [git_provider, command, f'version:{get_version()}'],
-                "trace_metadata": {
-                    "command": command,
-                    "pr_url": pr_url,
-                },
-            })
+            metadata.update(
+                {
+                    "trace_name": command,
+                    "tags": [git_provider, command, f'version:{get_version()}'],
+                    "trace_metadata": {
+                        "command": command,
+                        "pr_url": pr_url,
+                    },
+                }
+            )
         if "langsmith" in callbacks:
-            metadata.update({
-                "run_name": command,
-                "tags": [git_provider, command, f'version:{get_version()}'],
-                "extra": {
+            metadata.update(
+                {
+                    "run_name": command,
+                    "tags": [git_provider, command, f'version:{get_version()}'],
+                    "extra": {
@@ -160,7 +179,8 @@
-                    "pr_url": pr_url,
-                    }
-                },
-            })
+                        "pr_url": pr_url,
+                    }
+                },
+            }
+        )

         # Adding the captured logs to the kwargs
         kwargs["metadata"] = metadata
@@ -175,10 +195,19 @@
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

     @retry(
-        retry=retry_if_exception_type((openai.APIError, openai.APIConnectionError, openai.APITimeoutError)),  # No retry on RateLimitError
-        stop=stop_after_attempt(OPENAI_RETRIES)
+        retry=retry_if_exception_type(
+            (openai.APIError, openai.APIConnectionError, openai.APITimeoutError)
+        ),  # No retry on RateLimitError
+        stop=stop_after_attempt(OPENAI_RETRIES),
     )
-    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
+    async def chat_completion(
+        self,
+        model: str,
+        system: str,
+        user: str,
+        temperature: float = 0.2,
+        img_path: str = None,
+    ):
         try:
             resp, finish_reason = None, None
             deployment_id = self.deployment_id
@@ -187,8 +216,12 @@
             if 'claude' in model and not system:
                 system = "No system prompt provided"
                 get_logger().warning(
-                    "Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error.")
-            messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
+                    "Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error."
+                )
+            messages = [
+                {"role": "system", "content": system},
+                {"role": "user", "content": user},
+            ]

             if img_path:
                 try:
@@ -201,14 +234,21 @@
                 except Exception as e:
                     get_logger().error(f"Error fetching image: {img_path}", e)
                     return f"Error fetching image: {img_path}", "error"
-                messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]},
-                                          {"type": "image_url", "image_url": {"url": img_path}}]
+                messages[1]["content"] = [
+                    {"type": "text", "text": messages[1]["content"]},
+                    {"type": "image_url", "image_url": {"url": img_path}},
+                ]

             # Currently, some models do not support a separate system and user prompts
-            if model in self.user_message_only_models or get_settings().config.custom_reasoning_model:
+            if (
+                model in self.user_message_only_models
+                or get_settings().config.custom_reasoning_model
+            ):
                 user = f"{system}\n\n\n{user}"
                 system = ""
-                get_logger().info(f"Using model {model}, combining system and user prompts")
+                get_logger().info(
+                    f"Using model {model}, combining system and user prompts"
+                )
                 messages = [{"role": "user", "content": user}]
             kwargs = {
                 "model": model,
@@ -227,7 +267,10 @@
             }

             # Add temperature only if model supports it
-            if model not in self.no_support_temperature_models and not get_settings().config.custom_reasoning_model:
+            if (
+                model not in self.no_support_temperature_models
+                and not get_settings().config.custom_reasoning_model
+            ):
                 kwargs["temperature"] = temperature

             if get_settings().litellm.get("enable_callbacks", False):
@@ -235,7 +278,9 @@
                 seed = get_settings().config.get("seed", -1)
                 if temperature > 0 and seed >= 0:
-                    raise ValueError(f"Seed ({seed}) is not supported with temperature ({temperature}) > 0")
+                    raise ValueError(
+                        f"Seed ({seed}) is not supported with temperature ({temperature}) > 0"
+                    )
                 elif seed >= 0:
                     get_logger().info(f"Using fixed seed of {seed}")
                     kwargs["seed"] = seed
@@ -253,10 +298,10 @@
         except (openai.APIError, openai.APITimeoutError) as e:
             get_logger().warning(f"Error during LLM inference: {e}")
             raise
-        except (openai.RateLimitError) as e:
+        except openai.RateLimitError as e:
             get_logger().error(f"Rate limit error during LLM inference: {e}")
             raise
-        except (Exception) as e:
+        except Exception as e:
             get_logger().warning(f"Unknown error during LLM inference: {e}")
             raise openai.APIError from e
         if response is None or len(response["choices"]) == 0:
@@ -267,7 +312,9 @@
             get_logger().debug(f"\nAI response:\n{resp}")

             # log the full response for debugging
-            response_log = self.prepare_logs(response, system, user, resp, finish_reason)
+            response_log = self.prepare_logs(
+                response, system, user, resp, finish_reason
+            )
             get_logger().debug("Full_response", artifact=response_log)

             # for CLI debugging

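Note the retry-style switch in this handler: unlike the retry-package decorator in the other handlers, chat_completion here takes tenacity-style arguments (retry_if_exception_type, stop_after_attempt; the import sits outside this hunk), retrying only the listed exception types and giving up after a fixed number of attempts. Anything else, such as a rate-limit error, propagates immediately, as the inline comment notes. A minimal standalone sketch of the same shape:

    from tenacity import retry, retry_if_exception_type, stop_after_attempt

    @retry(
        retry=retry_if_exception_type((ConnectionError, TimeoutError)),
        stop=stop_after_attempt(5),
    )
    def call_api():
        # Retried up to 5 times on the two listed exceptions; any other
        # exception type is re-raised on the first occurrence.
        raise ConnectionError("still warming up")
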
View File

@@ -37,13 +37,23 @@ class OpenAIHandler(BaseAiHandler):
         """
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

-    @retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError),
-           tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
-    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
+    @retry(
+        exceptions=(APIError, Timeout, AttributeError, RateLimitError),
+        tries=OPENAI_RETRIES,
+        delay=2,
+        backoff=2,
+        jitter=(1, 3),
+    )
+    async def chat_completion(
+        self, model: str, system: str, user: str, temperature: float = 0.2
+    ):
         try:
             get_logger().info("System: ", system)
             get_logger().info("User: ", user)
-            messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
+            messages = [
+                {"role": "system", "content": system},
+                {"role": "user", "content": user},
+            ]

             client = AsyncOpenAI()
             chat_completion = await client.chat.completions.create(
                 model=model,
@@ -53,15 +63,21 @@
             resp = chat_completion.choices[0].message.content
             finish_reason = chat_completion.choices[0].finish_reason
             usage = chat_completion.usage
-            get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
-                              model=model, usage=usage)
+            get_logger().info(
+                "AI response",
+                response=resp,
+                messages=messages,
+                finish_reason=finish_reason,
+                model=model,
+                usage=usage,
+            )
             return resp, finish_reason
         except (APIError, Timeout) as e:
             get_logger().error("Error during OpenAI inference: ", e)
             raise
-        except (RateLimitError) as e:
+        except RateLimitError as e:
             get_logger().error("Rate limit error during OpenAI inference: ", e)
             raise
-        except (Exception) as e:
+        except Exception as e:
             get_logger().error("Unknown error during OpenAI inference: ", e)
             raise

View File

@@ -1,6 +1,7 @@
 from base64 import b64decode
 import hashlib

+
 class CliArgs:
     @staticmethod
     def validate_user_args(args: list) -> (bool, str):
@@ -23,12 +24,12 @@
             for arg in args:
                 if arg.startswith('--'):
                     arg_word = arg.lower()
-                    arg_word = arg_word.replace('__', '.')  # replace double underscore with dot, e.g. --openai__key -> --openai.key
+                    arg_word = arg_word.replace(
+                        '__', '.'
+                    )  # replace double underscore with dot, e.g. --openai__key -> --openai.key
                     for forbidden_arg_word in forbidden_cli_args:
                         if forbidden_arg_word in arg_word:
                             return False, forbidden_arg_word
             return True, ""
         except Exception as e:
             return False, str(e)

View File

@@ -15,7 +15,9 @@ def filter_ignored(files, platform = 'github'):
         if isinstance(patterns, str):
             patterns = [patterns]
         glob_setting = get_settings().ignore.glob
-        if isinstance(glob_setting, str):  # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
+        if isinstance(
+            glob_setting, str
+        ):  # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
             glob_setting = glob_setting.strip('[]').split(",")
             patterns += [fnmatch.translate(glob) for glob in glob_setting]
@@ -31,7 +33,9 @@
         if files and isinstance(files, list):
             for r in compiled_patterns:
                 if platform == 'github':
-                    files = [f for f in files if (f.filename and not r.match(f.filename))]
+                    files = [
+                        f for f in files if (f.filename and not r.match(f.filename))
+                    ]
                 elif platform == 'bitbucket':
                     # files = [f for f in files if (f.new.path and not r.match(f.new.path))]
                     files_o = []
@@ -49,10 +53,18 @@
                     # files = [f for f in files if (f['new_path'] and not r.match(f['new_path']))]
                     files_o = []
                     for f in files:
-                        if 'new_path' in f and f['new_path'] and not r.match(f['new_path']):
+                        if (
+                            'new_path' in f
+                            and f['new_path']
+                            and not r.match(f['new_path'])
+                        ):
                             files_o.append(f)
                             continue
-                        if 'old_path' in f and f['old_path'] and not r.match(f['old_path']):
+                        if (
+                            'old_path' in f
+                            and f['old_path']
+                            and not r.match(f['old_path'])
+                        ):
                             files_o.append(f)
                             continue
                     files = files_o

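fnmatch.translate, which filter_ignored uses above to fold the glob setting into its regex pattern list, turns a shell-style glob into a regular-expression string. A quick demonstration of what ends up being compiled:

    import fnmatch
    import re

    pattern = fnmatch.translate("*.lock")      # e.g. '(?s:.*\\.lock)\\Z'
    regex = re.compile(pattern)

    print(bool(regex.match("Pipfile.lock")))   # True
    print(bool(regex.match("Pipfile")))        # False
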
View File

@@ -8,9 +8,18 @@ from utils.pr_agent.config_loader import get_settings
 from utils.pr_agent.log import get_logger


-def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
-                 patch_extra_lines_after=0, filename: str = "") -> str:
-    if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0) or not original_file_str:
+def extend_patch(
+    original_file_str,
+    patch_str,
+    patch_extra_lines_before=0,
+    patch_extra_lines_after=0,
+    filename: str = "",
+) -> str:
+    if (
+        not patch_str
+        or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0)
+        or not original_file_str
+    ):
         return patch_str

     original_file_str = decode_if_bytes(original_file_str)
@@ -21,10 +30,17 @@
         return patch_str

     try:
-        extended_patch_str = process_patch_lines(patch_str, original_file_str,
-                                                 patch_extra_lines_before, patch_extra_lines_after)
+        extended_patch_str = process_patch_lines(
+            patch_str,
+            original_file_str,
+            patch_extra_lines_before,
+            patch_extra_lines_after,
+        )
     except Exception as e:
-        get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
+        get_logger().warning(
+            f"Failed to extend patch: {e}",
+            artifact={"traceback": traceback.format_exc()},
+        )
         return patch_str

     return extended_patch_str
@@ -48,13 +64,19 @@
 def should_skip_patch(filename):
     patch_extension_skip_types = get_settings().config.patch_extension_skip_types
     if patch_extension_skip_types and filename:
-        return any(filename.endswith(skip_type) for skip_type in patch_extension_skip_types)
+        return any(
+            filename.endswith(skip_type) for skip_type in patch_extension_skip_types
+        )
     return False


-def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after):
+def process_patch_lines(
+    patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after
+):
     allow_dynamic_context = get_settings().config.allow_dynamic_context
-    patch_extra_lines_before_dynamic = get_settings().config.max_extra_lines_before_dynamic_context
+    patch_extra_lines_before_dynamic = (
+        get_settings().config.max_extra_lines_before_dynamic_context
+    )

     original_lines = original_file_str.splitlines()
     len_original_lines = len(original_lines)
@@ -63,8 +85,7 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before,
     is_valid_hunk = True
     start1, size1, start2, size2 = -1, -1, -1, -1
-    RE_HUNK_HEADER = re.compile(
-        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
+    RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
     try:
         for i, line in enumerate(patch_lines):
             if line.startswith('@@'):
@@ -73,49 +94,113 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before,
                 if match:
                     # finish processing previous hunk
                     if is_valid_hunk and (start1 != -1 and patch_extra_lines_after > 0):
-                        delta_lines = [f' {line}' for line in original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]]
+                        delta_lines = [
+                            f' {line}'
+                            for line in original_lines[
+                                start1
+                                + size1
+                                - 1 : start1
+                                + size1
+                                - 1
+                                + patch_extra_lines_after
+                            ]
+                        ]
                         extended_patch_lines.extend(delta_lines)

-                    section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
+                    section_header, size1, size2, start1, start2 = extract_hunk_headers(
+                        match
+                    )
-                    is_valid_hunk = check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1)
+                    is_valid_hunk = check_if_hunk_lines_matches_to_file(
+                        i, original_lines, patch_lines, start1
+                    )

-                    if is_valid_hunk and (patch_extra_lines_before > 0 or patch_extra_lines_after > 0):
+                    if is_valid_hunk and (
+                        patch_extra_lines_before > 0 or patch_extra_lines_after > 0
+                    ):

                         def _calc_context_limits(patch_lines_before):
                             extended_start1 = max(1, start1 - patch_lines_before)
-                            extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after
+                            extended_size1 = (
+                                size1
+                                + (start1 - extended_start1)
+                                + patch_extra_lines_after
+                            )
                             extended_start2 = max(1, start2 - patch_lines_before)
-                            extended_size2 = size2 + (start2 - extended_start2) + patch_extra_lines_after
+                            extended_size2 = (
+                                size2
+                                + (start2 - extended_start2)
+                                + patch_extra_lines_after
+                            )
-                            if extended_start1 - 1 + extended_size1 > len_original_lines:
+                            if (
+                                extended_start1 - 1 + extended_size1
+                                > len_original_lines
+                            ):
                                 # we cannot extend beyond the original file
-                                delta_cap = extended_start1 - 1 + extended_size1 - len_original_lines
+                                delta_cap = (
+                                    extended_start1
+                                    - 1
+                                    + extended_size1
+                                    - len_original_lines
+                                )
                                 extended_size1 = max(extended_size1 - delta_cap, size1)
                                 extended_size2 = max(extended_size2 - delta_cap, size2)
-                            return extended_start1, extended_size1, extended_start2, extended_size2
+                            return (
+                                extended_start1,
+                                extended_size1,
+                                extended_start2,
+                                extended_size2,
+                            )

                         if allow_dynamic_context:
-                            extended_start1, extended_size1, extended_start2, extended_size2 = \
-                                _calc_context_limits(patch_extra_lines_before_dynamic)
-                            lines_before = original_lines[extended_start1 - 1:start1 - 1]
+                            (
+                                extended_start1,
+                                extended_size1,
+                                extended_start2,
+                                extended_size2,
+                            ) = _calc_context_limits(patch_extra_lines_before_dynamic)
+                            lines_before = original_lines[
+                                extended_start1 - 1 : start1 - 1
+                            ]
                             found_header = False
-                            for i, line, in enumerate(lines_before):
+                            for (
+                                i,
+                                line,
+                            ) in enumerate(lines_before):
                                 if section_header in line:
                                     found_header = True
                                     # Update start and size in one line each
-                                    extended_start1, extended_start2 = extended_start1 + i, extended_start2 + i
-                                    extended_size1, extended_size2 = extended_size1 - i, extended_size2 - i
+                                    extended_start1, extended_start2 = (
+                                        extended_start1 + i,
+                                        extended_start2 + i,
+                                    )
+                                    extended_size1, extended_size2 = (
+                                        extended_size1 - i,
+                                        extended_size2 - i,
+                                    )
                                     # get_logger().debug(f"Found section header in line {i} before the hunk")
                                     section_header = ''
                                     break
                             if not found_header:
                                 # get_logger().debug(f"Section header not found in the extra lines before the hunk")
-                                extended_start1, extended_size1, extended_start2, extended_size2 = \
-                                    _calc_context_limits(patch_extra_lines_before)
+                                (
+                                    extended_start1,
+                                    extended_size1,
+                                    extended_start2,
+                                    extended_size2,
+                                ) = _calc_context_limits(patch_extra_lines_before)
                         else:
-                            extended_start1, extended_size1, extended_start2, extended_size2 = \
-                                _calc_context_limits(patch_extra_lines_before)
+                            (
+                                extended_start1,
+                                extended_size1,
+                                extended_start2,
+                                extended_size2,
+                            ) = _calc_context_limits(patch_extra_lines_before)

-                        delta_lines = [f' {line}' for line in original_lines[extended_start1 - 1:start1 - 1]]
+                        delta_lines = [
+                            f' {line}'
+                            for line in original_lines[extended_start1 - 1 : start1 - 1]
+                        ]

                         # logic to remove section header if its in the extra delta lines (in dynamic context, this is also done)
                         if section_header and not allow_dynamic_context:
@@ -132,17 +217,23 @@
                     extended_patch_lines.append('')
                 extended_patch_lines.append(
                     f'@@ -{extended_start1},{extended_size1} '
-                    f'+{extended_start2},{extended_size2} @@ {section_header}')
+                    f'+{extended_start2},{extended_size2} @@ {section_header}'
+                )
                 extended_patch_lines.extend(delta_lines)  # one to zero based
                 continue
         extended_patch_lines.append(line)
     except Exception as e:
-        get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
+        get_logger().warning(
+            f"Failed to extend patch: {e}",
+            artifact={"traceback": traceback.format_exc()},
+        )
        return patch_str

     # finish processing last hunk
     if start1 != -1 and patch_extra_lines_after > 0 and is_valid_hunk:
-        delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
+        delta_lines = original_lines[
+            start1 + size1 - 1 : start1 + size1 - 1 + patch_extra_lines_after
+        ]
         # add space at the beginning of each extra line
         delta_lines = [f' {line}' for line in delta_lines]
         extended_patch_lines.extend(delta_lines)
@@ -158,11 +249,14 @@ def check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1):
     """
     is_valid_hunk = True
     try:
-        if i + 1 < len(patch_lines) and patch_lines[i + 1][0] == ' ':  # an existing line in the file
+        if (
+            i + 1 < len(patch_lines) and patch_lines[i + 1][0] == ' '
+        ):  # an existing line in the file
             if patch_lines[i + 1].strip() != original_lines[start1 - 1].strip():
                 is_valid_hunk = False
                 get_logger().error(
-                    f"Invalid hunk in PR, line {start1} in hunk header doesn't match the original file content")
+                    f"Invalid hunk in PR, line {start1} in hunk header doesn't match the original file content"
+                )
     except:
         pass
     return is_valid_hunk
@@ -195,8 +289,7 @@ def omit_deletion_hunks(patch_lines) -> str:
     added_patched = []
     add_hunk = False
     inside_hunk = False
-    RE_HUNK_HEADER = re.compile(
-        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))?\ @@[ ]?(.*)")
+    RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))?\ @@[ ]?(.*)")

     for line in patch_lines:
         if line.startswith('@@'):
@@ -221,8 +314,13 @@
     return '\n'.join(added_patched)


-def handle_patch_deletions(patch: str, original_file_content_str: str,
-                           new_file_content_str: str, file_name: str, edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN) -> str:
+def handle_patch_deletions(
+    patch: str,
+    original_file_content_str: str,
+    new_file_content_str: str,
+    file_name: str,
+    edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN,
+) -> str:
     """
     Handle entire file or deletion patches.
@@ -239,7 +337,9 @@
         str: The modified patch with deletion hunks omitted.
     """
-    if not new_file_content_str and (edit_type == EDIT_TYPE.DELETED or edit_type == EDIT_TYPE.UNKNOWN):
+    if not new_file_content_str and (
+        edit_type == EDIT_TYPE.DELETED or edit_type == EDIT_TYPE.UNKNOWN
+    ):
         # logic for handling deleted files - don't show patch, just show that the file was deleted
         if get_settings().config.verbosity_level > 0:
             get_logger().info(f"Processing file: {file_name}, minimizing deletion file")
@@ -292,8 +392,7 @@ __old hunk__
     patch_with_lines_str = f"\n\n## File: '{file.filename.strip()}'\n"
     patch_lines = patch.splitlines()
-    RE_HUNK_HEADER = re.compile(
-        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
+    RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
     new_content_lines = []
     old_content_lines = []
     match = None
@@ -307,20 +406,32 @@ __old hunk__
             if line.startswith('@@'):
                 header_line = line
                 match = RE_HUNK_HEADER.match(line)
-                if match and (new_content_lines or old_content_lines):  # found a new hunk, split the previous lines
+                if match and (
+                    new_content_lines or old_content_lines
+                ):  # found a new hunk, split the previous lines
                     if prev_header_line:
                         patch_with_lines_str += f'\n{prev_header_line}\n'
                     is_plus_lines = is_minus_lines = False
                     if new_content_lines:
-                        is_plus_lines = any([line.startswith('+') for line in new_content_lines])
+                        is_plus_lines = any(
+                            [line.startswith('+') for line in new_content_lines]
+                        )
                     if old_content_lines:
-                        is_minus_lines = any([line.startswith('-') for line in old_content_lines])
+                        is_minus_lines = any(
+                            [line.startswith('-') for line in old_content_lines]
+                        )
-                    if is_plus_lines or is_minus_lines:  # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
-                        patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__new hunk__\n'
+                    if (
+                        is_plus_lines or is_minus_lines
+                    ):  # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
+                        patch_with_lines_str = (
+                            patch_with_lines_str.rstrip() + '\n__new hunk__\n'
+                        )
                         for i, line_new in enumerate(new_content_lines):
                             patch_with_lines_str += f"{start2 + i} {line_new}\n"
                         if is_minus_lines:
-                            patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__old hunk__\n'
+                            patch_with_lines_str = (
+                                patch_with_lines_str.rstrip() + '\n__old hunk__\n'
+                            )
                             for line_old in old_content_lines:
                                 patch_with_lines_str += f"{line_old}\n"
                     new_content_lines = []
@@ -335,8 +446,12 @@
             elif line.startswith('-'):
                 old_content_lines.append(line)
             else:
-                if not line and line_i:  # if this line is empty and the next line is a hunk header, skip it
-                    if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith('@@'):
+                if (
+                    not line and line_i
+                ):  # if this line is empty and the next line is a hunk header, skip it
+                    if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith(
+                        '@@'
+                    ):
                         continue
                     elif line_i + 1 == len(patch_lines):
                         continue
@@ -351,7 +466,9 @@
         is_plus_lines = any([line.startswith('+') for line in new_content_lines])
     if old_content_lines:
         is_minus_lines = any([line.startswith('-') for line in old_content_lines])
-    if is_plus_lines or is_minus_lines:  # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
+    if (
+        is_plus_lines or is_minus_lines
+    ):  # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
         patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__new hunk__\n'
         for i, line_new in enumerate(new_content_lines):
             patch_with_lines_str += f"{start2 + i} {line_new}\n"
@@ -363,13 +480,16 @@
     return patch_with_lines_str.rstrip()


-def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, side) -> tuple[str, str]:
+def extract_hunk_lines_from_patch(
+    patch: str, file_name, line_start, line_end, side
+) -> tuple[str, str]:
     try:
         patch_with_lines_str = f"\n\n## File: '{file_name.strip()}'\n\n"
         selected_lines = ""
         patch_lines = patch.splitlines()
-        RE_HUNK_HEADER = re.compile(
-            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
+        RE_HUNK_HEADER = re.compile(
+            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)"
+        )
         match = None
         start1, size1, start2, size2 = -1, -1, -1, -1
         skip_hunk = False
@@ -385,7 +505,9 @@
                 match = RE_HUNK_HEADER.match(line)

-                section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
+                section_header, size1, size2, start1, start2 = extract_hunk_headers(
+                    match
+                )

                 # check if line range is in this hunk
                 if side.lower() == 'left':
@@ -400,15 +522,26 @@
                     patch_with_lines_str += f'\n{header_line}\n'
             elif not skip_hunk:
-                if side.lower() == 'right' and line_start <= start2 + selected_lines_num <= line_end:
+                if (
+                    side.lower() == 'right'
+                    and line_start <= start2 + selected_lines_num <= line_end
+                ):
                     selected_lines += line + '\n'
-                if side.lower() == 'left' and start1 <= selected_lines_num + start1 <= line_end:
+                if (
+                    side.lower() == 'left'
+                    and start1 <= selected_lines_num + start1 <= line_end
+                ):
                     selected_lines += line + '\n'
                 patch_with_lines_str += line + '\n'
-                if not line.startswith('-'):  # currently we don't support /ask line for deleted lines
+                if not line.startswith(
+                    '-'
+                ):  # currently we don't support /ask line for deleted lines
                     selected_lines_num += 1
     except Exception as e:
-        get_logger().error(f"Failed to extract hunk lines from patch: {e}", artifact={"traceback": traceback.format_exc()})
+        get_logger().error(
+            f"Failed to extract hunk lines from patch: {e}",
+            artifact={"traceback": traceback.format_exc()},
+        )
         return "", ""

     return patch_with_lines_str.rstrip(), selected_lines.rstrip()

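The RE_HUNK_HEADER pattern that black collapsed back onto one line above drives all of the hunk parsing in this module. A worked example of the groups it captures from a unified-diff header:

    import re

    RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

    m = RE_HUNK_HEADER.match("@@ -73,49 +94,113 @@ def process_patch_lines(...):")
    # old start/size, new start/size, then the optional section header
    print(m.groups())
    # ('73', '49', '94', '113', 'def process_patch_lines(...):')
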
View File

@ -9,7 +9,11 @@ def filter_bad_extensions(files):
bad_extensions = get_settings().bad_extensions.default bad_extensions = get_settings().bad_extensions.default
if get_settings().config.use_extra_bad_extensions: if get_settings().config.use_extra_bad_extensions:
bad_extensions += get_settings().bad_extensions.extra bad_extensions += get_settings().bad_extensions.extra
return [f for f in files if f.filename is not None and is_valid_file(f.filename, bad_extensions)] return [
f
for f in files
if f.filename is not None and is_valid_file(f.filename, bad_extensions)
]
def is_valid_file(filename: str, bad_extensions=None) -> bool: def is_valid_file(filename: str, bad_extensions=None) -> bool:
@@ -27,12 +31,16 @@ def sort_files_by_main_languages(languages: Dict, files: list):
     Sort files by their main language, put the files that are in the main language first and the rest files after
     """
     # sort languages by their size
-    languages_sorted_list = [k for k, v in sorted(languages.items(), key=lambda item: item[1], reverse=True)]
+    languages_sorted_list = [
+        k for k, v in sorted(languages.items(), key=lambda item: item[1], reverse=True)
+    ]
     # languages_sorted = sorted(languages, key=lambda x: x[1], reverse=True)
     # get all extensions for the languages
     main_extensions = []
     language_extension_map_org = get_settings().language_extension_map_org
-    language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
+    language_extension_map = {
+        k.lower(): v for k, v in language_extension_map_org.items()
+    }
     for language in languages_sorted_list:
         if language.lower() in language_extension_map:
             main_extensions.append(language_extension_map[language.lower()])
@@ -62,7 +70,9 @@ def sort_files_by_main_languages(languages: Dict, files: list):
             if extension_str in extensions:
                 tmp.append(file)
             else:
-                if (file.filename not in rest_files) and (extension_str not in main_extensions_flat):
+                if (file.filename not in rest_files) and (
+                    extension_str not in main_extensions_flat
+                ):
                     rest_files[file.filename] = file
         if len(tmp) > 0:
             files_sorted.append({"language": lang, "files": tmp})
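Editor's note: the language-ordering comprehension above boils down to sorting a name-to-size mapping by value, descending. A tiny illustration with invented byte counts:

languages = {"Python": 120345, "Makefile": 512, "Shell": 2048}
languages_sorted_list = [
    k for k, v in sorted(languages.items(), key=lambda item: item[1], reverse=True)
]
print(languages_sorted_list)  # -> ['Python', 'Shell', 'Makefile']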
View File
@@ -7,18 +7,28 @@ from github import RateLimitExceededException
 from utils.pr_agent.algo.file_filter import filter_ignored
 from utils.pr_agent.algo.git_patch_processing import (
-    convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions)
+    convert_to_hunks_with_lines_numbers,
+    extend_patch,
+    handle_patch_deletions,
+)
 from utils.pr_agent.algo.language_handler import sort_files_by_main_languages
 from utils.pr_agent.algo.token_handler import TokenHandler
 from utils.pr_agent.algo.types import EDIT_TYPE
-from utils.pr_agent.algo.utils import ModelType, clip_tokens, get_max_tokens, get_weak_model
+from utils.pr_agent.algo.utils import (
+    ModelType,
+    clip_tokens,
+    get_max_tokens,
+    get_weak_model,
+)
 from utils.pr_agent.config_loader import get_settings
 from utils.pr_agent.git_providers.git_provider import GitProvider
 from utils.pr_agent.log import get_logger

 DELETED_FILES_ = "Deleted files:\n"
-MORE_MODIFIED_FILES_ = "Additional modified files (insufficient token budget to process):\n"
+MORE_MODIFIED_FILES_ = (
+    "Additional modified files (insufficient token budget to process):\n"
+)
 ADDED_FILES_ = "Additional added files (insufficient token budget to process):\n"
@@ -29,45 +39,59 @@ MAX_EXTRA_LINES = 10

 def cap_and_log_extra_lines(value, direction) -> int:
     if value > MAX_EXTRA_LINES:
-        get_logger().warning(f"patch_extra_lines_{direction} was {value}, capping to {MAX_EXTRA_LINES}")
+        get_logger().warning(
+            f"patch_extra_lines_{direction} was {value}, capping to {MAX_EXTRA_LINES}"
+        )
         return MAX_EXTRA_LINES
     return value

-def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
-                model: str,
-                add_line_numbers_to_hunks: bool = False,
-                disable_extra_lines: bool = False,
-                large_pr_handling=False,
-                return_remaining_files=False):
+def get_pr_diff(
+    git_provider: GitProvider,
+    token_handler: TokenHandler,
+    model: str,
+    add_line_numbers_to_hunks: bool = False,
+    disable_extra_lines: bool = False,
+    large_pr_handling=False,
+    return_remaining_files=False,
+):
     if disable_extra_lines:
         PATCH_EXTRA_LINES_BEFORE = 0
         PATCH_EXTRA_LINES_AFTER = 0
     else:
         PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
         PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
-    PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(PATCH_EXTRA_LINES_BEFORE, "before")
-    PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(PATCH_EXTRA_LINES_AFTER, "after")
+    PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(
+        PATCH_EXTRA_LINES_BEFORE, "before"
+    )
+    PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(
+        PATCH_EXTRA_LINES_AFTER, "after"
+    )

     try:
         diff_files_original = git_provider.get_diff_files()
     except RateLimitExceededException as e:
-        get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
+        get_logger().error(
+            f"Rate limit exceeded for git provider API. original message {e}"
+        )
         raise

     diff_files = filter_ignored(diff_files_original)
     if diff_files != diff_files_original:
         try:
-            get_logger().info(f"Filtered out {len(diff_files_original) - len(diff_files)} files")
+            get_logger().info(
+                f"Filtered out {len(diff_files_original) - len(diff_files)} files"
+            )
             new_names = set([a.filename for a in diff_files])
             orig_names = set([a.filename for a in diff_files_original])
             get_logger().info(f"Filtered out files: {orig_names - new_names}")
         except Exception as e:
             pass

     # get pr languages
-    pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
+    pr_languages = sort_files_by_main_languages(
+        git_provider.get_languages(), diff_files
+    )
     if pr_languages:
         try:
             get_logger().info(f"PR main language: {pr_languages[0]['language']}")
@@ -76,23 +100,41 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
     # generate a standard diff string, with patch extension
     patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
-        pr_languages, token_handler, add_line_numbers_to_hunks,
-        patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE, patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)
+        pr_languages,
+        token_handler,
+        add_line_numbers_to_hunks,
+        patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE,
+        patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER,
+    )

     # if we are under the limit, return the full diff
     if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
-        get_logger().info(f"Tokens: {total_tokens}, total tokens under limit: {get_max_tokens(model)}, "
-                          f"returning full diff.")
+        get_logger().info(
+            f"Tokens: {total_tokens}, total tokens under limit: {get_max_tokens(model)}, "
+            f"returning full diff."
+        )
         return "\n".join(patches_extended)

     # if we are over the limit, start pruning (If we got here, we will not extend the patches with extra lines)
-    get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
-                      f"pruning diff.")
-    patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
-        pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling)
+    get_logger().info(
+        f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
+        f"pruning diff."
+    )
+    (
+        patches_compressed_list,
+        total_tokens_list,
+        deleted_files_list,
+        remaining_files_list,
+        file_dict,
+        files_in_patches_list,
+    ) = pr_generate_compressed_diff(
+        pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling
+    )

     if large_pr_handling and len(patches_compressed_list) > 1:
-        get_logger().info(f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff.")
+        get_logger().info(
+            f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff."
+        )
         return ""  # return empty string, as we want to generate multiple patches with a different prompt

     # return the first patch
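Editor's note: the branch above reduces to one budget comparison: accept the full diff only if it leaves room for the model's own output. A sketch with an assumed threshold value (the real constant is defined elsewhere in this module):

OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000  # assumed here for illustration

def fits_in_context(total_tokens: int, max_model_tokens: int) -> bool:
    # keep a buffer free for the model's answer before returning the full diff
    return total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < max_model_tokens

print(fits_in_context(6500, 8192))  # True  -> return the full extended diff
print(fits_in_context(7500, 8192))  # False -> fall through to the pruning path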
@@ -144,26 +186,37 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
     if deleted_list_str:
         final_diff = final_diff + "\n\n" + deleted_list_str
-    get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
-                       f"deleted_list_str: {deleted_list_str}")
+    get_logger().debug(
+        f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
+        f"deleted_list_str: {deleted_list_str}"
+    )
     if not return_remaining_files:
         return final_diff
     else:
         return final_diff, remaining_files_list

-def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler, model: str,
-                                add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False):
+def get_pr_diff_multiple_patchs(
+    git_provider: GitProvider,
+    token_handler: TokenHandler,
+    model: str,
+    add_line_numbers_to_hunks: bool = False,
+    disable_extra_lines: bool = False,
+):
     try:
         diff_files_original = git_provider.get_diff_files()
     except RateLimitExceededException as e:
-        get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
+        get_logger().error(
+            f"Rate limit exceeded for git provider API. original message {e}"
+        )
         raise

     diff_files = filter_ignored(diff_files_original)
     if diff_files != diff_files_original:
         try:
-            get_logger().info(f"Filtered out {len(diff_files_original) - len(diff_files)} files")
+            get_logger().info(
+                f"Filtered out {len(diff_files_original) - len(diff_files)} files"
+            )
             new_names = set([a.filename for a in diff_files])
             orig_names = set([a.filename for a in diff_files_original])
             get_logger().info(f"Filtered out files: {orig_names - new_names}")
@@ -171,24 +224,47 @@ def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenH
             pass

     # get pr languages
-    pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
+    pr_languages = sort_files_by_main_languages(
+        git_provider.get_languages(), diff_files
+    )
     if pr_languages:
         try:
             get_logger().info(f"PR main language: {pr_languages[0]['language']}")
         except Exception as e:
             pass

-    patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
-        pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling=True)
+    (
+        patches_compressed_list,
+        total_tokens_list,
+        deleted_files_list,
+        remaining_files_list,
+        file_dict,
+        files_in_patches_list,
+    ) = pr_generate_compressed_diff(
+        pr_languages,
+        token_handler,
+        model,
+        add_line_numbers_to_hunks,
+        large_pr_handling=True,
+    )

-    return patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
+    return (
+        patches_compressed_list,
+        total_tokens_list,
+        deleted_files_list,
+        remaining_files_list,
+        file_dict,
+        files_in_patches_list,
+    )

-def pr_generate_extended_diff(pr_languages: list,
-                              token_handler: TokenHandler,
-                              add_line_numbers_to_hunks: bool,
-                              patch_extra_lines_before: int = 0,
-                              patch_extra_lines_after: int = 0) -> Tuple[list, int, list]:
+def pr_generate_extended_diff(
+    pr_languages: list,
+    token_handler: TokenHandler,
+    add_line_numbers_to_hunks: bool,
+    patch_extra_lines_before: int = 0,
+    patch_extra_lines_after: int = 0,
+) -> Tuple[list, int, list]:
     total_tokens = token_handler.prompt_tokens  # initial tokens
     patches_extended = []
     patches_extended_tokens = []
@@ -200,20 +276,33 @@ def pr_generate_extended_diff(pr_languages: list,
             continue

         # extend each patch with extra lines of context
-        extended_patch = extend_patch(original_file_content_str, patch,
-                                      patch_extra_lines_before, patch_extra_lines_after, file.filename)
+        extended_patch = extend_patch(
+            original_file_content_str,
+            patch,
+            patch_extra_lines_before,
+            patch_extra_lines_after,
+            file.filename,
+        )
         if not extended_patch:
-            get_logger().warning(f"Failed to extend patch for file: {file.filename}")
+            get_logger().warning(
+                f"Failed to extend patch for file: {file.filename}"
+            )
             continue

         if add_line_numbers_to_hunks:
-            full_extended_patch = convert_to_hunks_with_lines_numbers(extended_patch, file)
+            full_extended_patch = convert_to_hunks_with_lines_numbers(
+                extended_patch, file
+            )
         else:
             full_extended_patch = f"\n\n## File: '{file.filename.strip()}'\n{extended_patch.rstrip()}\n"

         # add AI-summary metadata to the patch
-        if file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False):
-            full_extended_patch = add_ai_summary_top_patch(file, full_extended_patch)
+        if file.ai_file_summary and get_settings().get(
+            "config.enable_ai_metadata", False
+        ):
+            full_extended_patch = add_ai_summary_top_patch(
+                file, full_extended_patch
+            )

         patch_tokens = token_handler.count_tokens(full_extended_patch)
         file.tokens = patch_tokens
@@ -224,9 +313,13 @@ def pr_generate_extended_diff(pr_languages: list,
     return patches_extended, total_tokens, patches_extended_tokens

-def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
-                                convert_hunks_to_line_numbers: bool,
-                                large_pr_handling: bool) -> Tuple[list, list, list, list, dict, list]:
+def pr_generate_compressed_diff(
+    top_langs: list,
+    token_handler: TokenHandler,
+    model: str,
+    convert_hunks_to_line_numbers: bool,
+    large_pr_handling: bool,
+) -> Tuple[list, list, list, list, dict, list]:
     deleted_files_list = []

     # sort each one of the languages in top_langs by the number of tokens in the diff
@@ -244,8 +337,13 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
             continue

         # removing delete-only hunks
-        patch = handle_patch_deletions(patch, original_file_content_str,
-                                       new_file_content_str, file.filename, file.edit_type)
+        patch = handle_patch_deletions(
+            patch,
+            original_file_content_str,
+            new_file_content_str,
+            file.filename,
+            file.edit_type,
+        )
         if patch is None:
             if file.filename not in deleted_files_list:
                 deleted_files_list.append(file.filename)
@@ -259,7 +357,11 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
         #     patch = add_ai_summary_top_patch(file, patch)

         new_patch_tokens = token_handler.count_tokens(patch)
-        file_dict[file.filename] = {'patch': patch, 'tokens': new_patch_tokens, 'edit_type': file.edit_type}
+        file_dict[file.filename] = {
+            'patch': patch,
+            'tokens': new_patch_tokens,
+            'edit_type': file.edit_type,
+        }

     max_tokens_model = get_max_tokens(model)
@@ -268,21 +370,41 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
     remaining_files_list = [file.filename for file in sorted_files]
     patches_list = []
     total_tokens_list = []
-    total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers, file_dict,
-                                                                                           max_tokens_model, remaining_files_list, token_handler)
+    (
+        total_tokens,
+        patches,
+        remaining_files_list,
+        files_in_patch_list,
+    ) = generate_full_patch(
+        convert_hunks_to_line_numbers,
+        file_dict,
+        max_tokens_model,
+        remaining_files_list,
+        token_handler,
+    )
     patches_list.append(patches)
     total_tokens_list.append(total_tokens)
     files_in_patches_list.append(files_in_patch_list)

     # additional iterations (if needed)
     if large_pr_handling:
-        NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1  # one more call is to summarize
+        NUMBER_OF_ALLOWED_ITERATIONS = (
+            get_settings().pr_description.max_ai_calls - 1
+        )  # one more call is to summarize
         for i in range(NUMBER_OF_ALLOWED_ITERATIONS - 1):
             if remaining_files_list:
-                total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers,
-                                                                                                       file_dict,
-                                                                                                       max_tokens_model,
-                                                                                                       remaining_files_list, token_handler)
+                (
+                    total_tokens,
+                    patches,
+                    remaining_files_list,
+                    files_in_patch_list,
+                ) = generate_full_patch(
+                    convert_hunks_to_line_numbers,
+                    file_dict,
+                    max_tokens_model,
+                    remaining_files_list,
+                    token_handler,
+                )
                 if patches:
                     patches_list.append(patches)
                     total_tokens_list.append(total_tokens)
@@ -290,10 +412,23 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
             else:
                 break

-    return patches_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
+    return (
+        patches_list,
+        total_tokens_list,
+        deleted_files_list,
+        remaining_files_list,
+        file_dict,
+        files_in_patches_list,
+    )

-def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_model,remaining_files_list_prev, token_handler):
+def generate_full_patch(
+    convert_hunks_to_line_numbers,
+    file_dict,
+    max_tokens_model,
+    remaining_files_list_prev,
+    token_handler,
+):
     total_tokens = token_handler.prompt_tokens  # initial tokens
     patches = []
     remaining_files_list_new = []
@@ -312,7 +447,10 @@ def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_mod
             continue

         # If the patch is too large, just show the file name
-        if total_tokens + new_patch_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
+        if (
+            total_tokens + new_patch_tokens
+            > max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
+        ):
             # Current logic is to skip the patch if it's too large
             # TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
             # until we meet the requirements
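Editor's note: a minimal sketch of the greedy packing loop that generate_full_patch implements: take files in order, keep a patch while it fits under the budget, and defer the rest. Names and numbers are invented for illustration:

def pack_patches(file_dict: dict, max_tokens: int, prompt_tokens: int):
    total_tokens = prompt_tokens
    patches, remaining = [], []
    for filename, info in file_dict.items():
        if total_tokens + info['tokens'] > max_tokens:
            remaining.append(filename)  # no budget left for this patch
            continue
        patches.append(info['patch'])
        total_tokens += info['tokens']
    return total_tokens, patches, remaining

file_dict = {'a.py': {'patch': '...', 'tokens': 300},
             'b.py': {'patch': '...', 'tokens': 900}}
print(pack_patches(file_dict, max_tokens=1000, prompt_tokens=100))
# -> (400, ['...'], ['b.py'])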
@@ -334,7 +472,9 @@ def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_mod
     return total_tokens, patches, remaining_files_list_new, files_in_patch_list

-async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
+async def retry_with_fallback_models(
+    f: Callable, model_type: ModelType = ModelType.REGULAR
+):
     all_models = _get_all_models(model_type)
     all_deployments = _get_all_deployments(all_models)
     # try each (model, deployment_id) pair until one is successful, otherwise raise exception
@@ -347,11 +487,11 @@ async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelT
             get_settings().set("openai.deployment_id", deployment_id)
             return await f(model)
         except:
-            get_logger().warning(
-                f"Failed to generate prediction with {model}"
-            )
+            get_logger().warning(f"Failed to generate prediction with {model}")
             if i == len(all_models) - 1:  # If it's the last iteration
-                raise Exception(f"Failed to generate prediction with any model of {all_models}")
+                raise Exception(
+                    f"Failed to generate prediction with any model of {all_models}"
+                )

 def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
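Editor's note: a self-contained sketch of the fallback pattern above: try each (model, deployment) pair in order and re-raise only after the last one fails. The model names and the flaky coroutine are invented:

import asyncio

async def retry_with_fallback(f, models, deployments):
    for i, (model, deployment_id) in enumerate(zip(models, deployments)):
        try:
            return await f(model)
        except Exception:
            print(f"Failed to generate prediction with {model}")
            if i == len(models) - 1:  # last pair, nothing left to try
                raise Exception(f"Failed to generate prediction with any model of {models}")

async def flaky(model):
    if model == "weak-model":
        raise RuntimeError("boom")
    return f"ok from {model}"

print(asyncio.run(retry_with_fallback(flaky, ["weak-model", "strong-model"], ["d1", "d2"])))
# prints the failure line for weak-model, then: ok from strong-model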
@@ -374,17 +514,21 @@ def _get_all_deployments(all_models: List[str]) -> List[str]:
     if fallback_deployments:
         all_deployments = [deployment_id] + fallback_deployments
         if len(all_deployments) < len(all_models):
-            raise ValueError(f"The number of deployments ({len(all_deployments)}) "
-                             f"is less than the number of models ({len(all_models)})")
+            raise ValueError(
+                f"The number of deployments ({len(all_deployments)}) "
+                f"is less than the number of models ({len(all_models)})"
+            )
     else:
         all_deployments = [deployment_id] * len(all_models)
     return all_deployments

-def get_pr_multi_diffs(git_provider: GitProvider,
-                       token_handler: TokenHandler,
-                       model: str,
-                       max_calls: int = 5) -> List[str]:
+def get_pr_multi_diffs(
+    git_provider: GitProvider,
+    token_handler: TokenHandler,
+    model: str,
+    max_calls: int = 5,
+) -> List[str]:
     """
     Retrieves the diff files from a Git provider, sorts them by main language, and generates patches for each file.
     The patches are split into multiple groups based on the maximum number of tokens allowed for the given model.
@@ -404,13 +548,17 @@ def get_pr_multi_diffs(git_provider: GitProvider,
     try:
         diff_files = git_provider.get_diff_files()
     except RateLimitExceededException as e:
-        get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
+        get_logger().error(
+            f"Rate limit exceeded for git provider API. original message {e}"
+        )
         raise

     diff_files = filter_ignored(diff_files)

     # Sort files by main language
-    pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
+    pr_languages = sort_files_by_main_languages(
+        git_provider.get_languages(), diff_files
+    )

     # Sort files within each language group by tokens in descending order
     sorted_files = []
@@ -420,14 +568,19 @@ def get_pr_multi_diffs(git_provider: GitProvider,
     # Get the maximum number of extra lines before and after the patch
     PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
     PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
-    PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(PATCH_EXTRA_LINES_BEFORE, "before")
+    PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(
+        PATCH_EXTRA_LINES_BEFORE, "before"
+    )
     PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(PATCH_EXTRA_LINES_AFTER, "after")

     # try first a single run with standard diff string, with patch extension, and no deletions
     patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
-        pr_languages, token_handler, add_line_numbers_to_hunks=True,
-        patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE,
-        patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)
+        pr_languages,
+        token_handler,
+        add_line_numbers_to_hunks=True,
+        patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE,
+        patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER,
+    )

     # if we are under the limit, return the full diff
     if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
@@ -450,27 +603,50 @@ def get_pr_multi_diffs(git_provider: GitProvider,
             continue

         # Remove delete-only hunks
-        patch = handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file.filename, file.edit_type)
+        patch = handle_patch_deletions(
+            patch,
+            original_file_content_str,
+            new_file_content_str,
+            file.filename,
+            file.edit_type,
+        )
         if patch is None:
             continue

         patch = convert_to_hunks_with_lines_numbers(patch, file)

         # add AI-summary metadata to the patch
-        if file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False):
+        if file.ai_file_summary and get_settings().get(
+            "config.enable_ai_metadata", False
+        ):
             patch = add_ai_summary_top_patch(file, patch)

         new_patch_tokens = token_handler.count_tokens(patch)
-        if patch and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
-                model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
+        if (
+            patch
+            and (token_handler.prompt_tokens + new_patch_tokens)
+            > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
+        ):
             if get_settings().config.get('large_patch_policy', 'skip') == 'skip':
                 get_logger().warning(f"Patch too large, skipping: {file.filename}")
                 continue
             elif get_settings().config.get('large_patch_policy') == 'clip':
-                delta_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens
-                patch_clipped = clip_tokens(patch, delta_tokens, delete_last_line=True, num_input_tokens=new_patch_tokens)
+                delta_tokens = (
+                    get_max_tokens(model)
+                    - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
+                    - token_handler.prompt_tokens
+                )
+                patch_clipped = clip_tokens(
+                    patch,
+                    delta_tokens,
+                    delete_last_line=True,
+                    num_input_tokens=new_patch_tokens,
+                )
                 new_patch_tokens = token_handler.count_tokens(patch_clipped)
-                if patch_clipped and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
-                        model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
+                if (
+                    patch_clipped
+                    and (token_handler.prompt_tokens + new_patch_tokens)
+                    > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
+                ):
                     get_logger().warning(f"Patch too large, skipping: {file.filename}")
                     continue
                 else:
@@ -480,7 +656,10 @@ def get_pr_multi_diffs(git_provider: GitProvider,
                 get_logger().warning(f"Patch too large, skipping: {file.filename}")
                 continue

-        if patch and (total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
+        if patch and (
+            total_tokens + new_patch_tokens
+            > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
+        ):
             final_diff = "\n".join(patches)
             final_diff_list.append(final_diff)
             patches = []
@@ -497,7 +676,9 @@ def get_pr_multi_diffs(git_provider: GitProvider,
         patches.append(patch)
         total_tokens += new_patch_tokens
         if get_settings().config.verbosity_level >= 2:
-            get_logger().info(f"Tokens: {total_tokens}, last filename: {file.filename}")
+            get_logger().info(
+                f"Tokens: {total_tokens}, last filename: {file.filename}"
+            )

     # Add the last chunk
     if patches:
@@ -515,7 +696,10 @@ def add_ai_metadata_to_diff_files(git_provider, pr_description_files):
     if not pr_description_files:
         get_logger().warning(f"PR description files are empty.")
         return
-    available_files = {pr_file['full_file_name'].strip(): pr_file for pr_file in pr_description_files}
+    available_files = {
+        pr_file['full_file_name'].strip(): pr_file
+        for pr_file in pr_description_files
+    }
     diff_files = git_provider.get_diff_files()
     found_any_match = False
     for file in diff_files:
@@ -524,11 +708,15 @@ def add_ai_metadata_to_diff_files(git_provider, pr_description_files):
                 file.ai_file_summary = available_files[filename]
                 found_any_match = True
         if not found_any_match:
-            get_logger().error(f"Failed to find any matching files between PR description and diff files.",
-                               artifact={"pr_description_files": pr_description_files})
+            get_logger().error(
+                f"Failed to find any matching files between PR description and diff files.",
+                artifact={"pr_description_files": pr_description_files},
+            )
     except Exception as e:
-        get_logger().error(f"Failed to add AI metadata to diff files: {e}",
-                           artifact={"traceback": traceback.format_exc()})
+        get_logger().error(
+            f"Failed to add AI metadata to diff files: {e}",
+            artifact={"traceback": traceback.format_exc()},
+        )

 def add_ai_summary_top_patch(file, full_extended_patch):
@@ -537,14 +725,18 @@ def add_ai_summary_top_patch(file, full_extended_patch):
         full_extended_patch_lines = full_extended_patch.split("\n")
         for i, line in enumerate(full_extended_patch_lines):
             if line.startswith("## File:") or line.startswith("## file:"):
-                full_extended_patch_lines.insert(i + 1,
-                                                 f"### AI-generated changes summary:\n{file.ai_file_summary['long_summary']}")
+                full_extended_patch_lines.insert(
+                    i + 1,
+                    f"### AI-generated changes summary:\n{file.ai_file_summary['long_summary']}",
+                )
                 full_extended_patch = "\n".join(full_extended_patch_lines)
                 return full_extended_patch
         # if no '## File: ...' was found
         return full_extended_patch
     except Exception as e:
-        get_logger().error(f"Failed to add AI summary to the top of the patch: {e}",
-                           artifact={"traceback": traceback.format_exc()})
+        get_logger().error(
+            f"Failed to add AI summary to the top of the patch: {e}",
+            artifact={"traceback": traceback.format_exc()},
+        )
         return full_extended_patch
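Editor's note: a minimal sketch of the insertion logic above: find the "## File:" header line and splice the summary in right after it. The patch text and summary are invented:

patch_lines = "## File: 'app.py'\n@@ -1,2 +1,3 @@\n+print('hi')".split("\n")
summary = "Adds a greeting on startup."
for i, line in enumerate(patch_lines):
    if line.startswith("## File:"):
        # insert the summary block directly below the file header
        patch_lines.insert(i + 1, f"### AI-generated changes summary:\n{summary}")
        break
print("\n".join(patch_lines))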
View File
@@ -15,12 +15,17 @@ class TokenEncoder:
     @classmethod
     def get_token_encoder(cls):
         model = get_settings().config.model
-        if cls._encoder_instance is None or model != cls._model:  # Check without acquiring the lock for performance
+        if (
+            cls._encoder_instance is None or model != cls._model
+        ):  # Check without acquiring the lock for performance
             with cls._lock:  # Lock acquisition to ensure thread safety
                 if cls._encoder_instance is None or model != cls._model:
                     cls._model = model
-                    cls._encoder_instance = encoding_for_model(cls._model) if "gpt" in cls._model else get_encoding(
-                        "cl100k_base")
+                    cls._encoder_instance = (
+                        encoding_for_model(cls._model)
+                        if "gpt" in cls._model
+                        else get_encoding("cl100k_base")
+                    )
         return cls._encoder_instance
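Editor's note: the cache above is the double-checked locking pattern: a lock-free fast path, then a re-check inside the lock so only one thread ever builds the instance. A stripped-down sketch of the same idea:

import threading

class Singleton:
    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get(cls):
        if cls._instance is None:  # fast path: no lock once initialized
            with cls._lock:
                if cls._instance is None:  # re-check inside the lock
                    cls._instance = object()
        return cls._instance

print(Singleton.get() is Singleton.get())  # -> True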
@@ -49,7 +54,9 @@ class TokenHandler:
         """
         self.encoder = TokenEncoder.get_token_encoder()
         if pr is not None:
-            self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
+            self.prompt_tokens = self._get_system_user_tokens(
+                pr, self.encoder, vars, system, user
+            )

     def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
         """
View File
@@ -41,10 +41,12 @@ class Range(BaseModel):
     column_start: int = -1
     column_end: int = -1

 class ModelType(str, Enum):
     REGULAR = "regular"
     WEAK = "weak"

 class PRReviewHeader(str, Enum):
     REGULAR = "## PR 评审指南"
     INCREMENTAL = "## 增量 PR 评审指南"
@@ -57,7 +59,9 @@ class PRDescriptionHeader(str, Enum):
 def get_setting(key: str) -> Any:
     try:
         key = key.upper()
-        return context.get("settings", global_settings).get(key, global_settings.get(key, None))
+        return context.get("settings", global_settings).get(
+            key, global_settings.get(key, None)
+        )
     except Exception:
         return global_settings.get(key, None)
@@ -72,14 +76,29 @@ def emphasize_header(text: str, only_markdown=False, reference_link=None) -> str
             # Everything before the colon (inclusive) is wrapped in <strong> tags
             if only_markdown:
                 if reference_link:
-                    transformed_string = f"[**{text[:colon_position + 1]}**]({reference_link})\n" + text[colon_position + 1:]
+                    transformed_string = (
+                        f"[**{text[:colon_position + 1]}**]({reference_link})\n"
+                        + text[colon_position + 1 :]
+                    )
                 else:
-                    transformed_string = f"**{text[:colon_position + 1]}**\n" + text[colon_position + 1:]
+                    transformed_string = (
+                        f"**{text[:colon_position + 1]}**\n"
+                        + text[colon_position + 1 :]
+                    )
             else:
                 if reference_link:
-                    transformed_string = f"<strong><a href='{reference_link}'>{text[:colon_position + 1]}</a></strong><br>" + text[colon_position + 1:]
+                    transformed_string = (
+                        f"<strong><a href='{reference_link}'>{text[:colon_position + 1]}</a></strong><br>"
+                        + text[colon_position + 1 :]
+                    )
                 else:
-                    transformed_string = "<strong>" + text[:colon_position + 1] + "</strong>" +'<br>' + text[colon_position + 1:]
+                    transformed_string = (
+                        "<strong>"
+                        + text[: colon_position + 1]
+                        + "</strong>"
+                        + '<br>'
+                        + text[colon_position + 1 :]
+                    )
         else:
             # If there's no ": ", return the original string
             transformed_string = text
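Editor's note: the markdown branch of emphasize_header reduces to a slice-and-wrap around the first ": ". A tiny illustration with an invented input:

text = "Possible bug: the loop never terminates"
colon_position = text.find(": ")
transformed_string = "**" + text[: colon_position + 1] + "**\n" + text[colon_position + 1 :]
print(transformed_string)
# -> **Possible bug:**
# ->  the loop never terminates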
@@ -101,11 +120,14 @@ def unique_strings(input_list: List[str]) -> List[str]:
             seen.add(item)
     return unique_list

-def convert_to_markdown_v2(output_data: dict,
-                           gfm_supported: bool = True,
-                           incremental_review=None,
-                           git_provider=None,
-                           files=None) -> str:
+def convert_to_markdown_v2(
+    output_data: dict,
+    gfm_supported: bool = True,
+    incremental_review=None,
+    git_provider=None,
+    files=None,
+) -> str:
     """
     Convert a dictionary of data into markdown format.
     Args:
@@ -183,7 +205,9 @@ def convert_to_markdown_v2(output_data: dict,
             else:
                 markdown_text += f"### {emoji} PR 包含测试\n\n"
         elif 'ticket compliance check' in key_nice.lower():
-            markdown_text = ticket_markdown_logic(emoji, markdown_text, value, gfm_supported)
+            markdown_text = ticket_markdown_logic(
+                emoji, markdown_text, value, gfm_supported
+            )
         elif 'security concerns' in key_nice.lower():
             if gfm_supported:
                 markdown_text += f"<tr><td>"
@@ -220,7 +244,9 @@ def convert_to_markdown_v2(output_data: dict,
             if gfm_supported:
                 markdown_text += f"<tr><td>"
                 # markdown_text += f"{emoji}&nbsp;<strong>{key_nice}</strong><br><br>\n\n"
-                markdown_text += f"{emoji}&nbsp;<strong>建议评审的重点领域</strong><br><br>\n\n"
+                markdown_text += (
+                    f"{emoji}&nbsp;<strong>建议评审的重点领域</strong><br><br>\n\n"
+                )
             else:
                 markdown_text += f"### {emoji} 建议评审的重点领域\n\n#### \n"
             for i, issue in enumerate(issues):
@@ -235,9 +261,13 @@ def convert_to_markdown_v2(output_data: dict,
                     start_line = int(str(issue.get('start_line', 0)).strip())
                     end_line = int(str(issue.get('end_line', 0)).strip())

-                    relevant_lines_str = extract_relevant_lines_str(end_line, files, relevant_file, start_line, dedent=True)
+                    relevant_lines_str = extract_relevant_lines_str(
+                        end_line, files, relevant_file, start_line, dedent=True
+                    )
                     if git_provider:
-                        reference_link = git_provider.get_line_link(relevant_file, start_line, end_line)
+                        reference_link = git_provider.get_line_link(
+                            relevant_file, start_line, end_line
+                        )
                     else:
                         reference_link = None
@@ -256,7 +286,9 @@ def convert_to_markdown_v2(output_data: dict,
                         issue_str = f"**{issue_header}**\n\n{issue_content}\n\n"
                     markdown_text += f"{issue_str}\n\n"
                 except Exception as e:
-                    get_logger().exception(f"Failed to process 'Recommended focus areas for review': {e}")
+                    get_logger().exception(
+                        f"Failed to process 'Recommended focus areas for review': {e}"
+                    )
             if gfm_supported:
                 markdown_text += f"</td></tr>\n"
         else:
@@ -273,7 +305,9 @@ def convert_to_markdown_v2(output_data: dict,
     return markdown_text

-def extract_relevant_lines_str(end_line, files, relevant_file, start_line, dedent=False) -> str:
+def extract_relevant_lines_str(
+    end_line, files, relevant_file, start_line, dedent=False
+) -> str:
     """
     Finds 'relevant_file' in 'files', and extracts the lines from 'start_line' to 'end_line' string from the file content.
     """
@@ -286,10 +320,16 @@ def extract_relevant_lines_str(end_line, files, relevant_file, start_line, deden
             if not file.head_file:
                 # as a fallback, extract relevant lines directly from patch
                 patch = file.patch
-                get_logger().info(f"No content found in file: '{file.filename}' for 'extract_relevant_lines_str'. Using patch instead")
-                _, selected_lines = extract_hunk_lines_from_patch(patch, file.filename, start_line, end_line,side='right')
+                get_logger().info(
+                    f"No content found in file: '{file.filename}' for 'extract_relevant_lines_str'. Using patch instead"
+                )
+                _, selected_lines = extract_hunk_lines_from_patch(
+                    patch, file.filename, start_line, end_line, side='right'
+                )
                 if not selected_lines:
-                    get_logger().error(f"Failed to extract relevant lines from patch: {file.filename}")
+                    get_logger().error(
+                        f"Failed to extract relevant lines from patch: {file.filename}"
+                    )
                     return ""
                 # filter out '-' lines
                 relevant_lines_str = ""
@@ -299,12 +339,16 @@ def extract_relevant_lines_str(end_line, files, relevant_file, start_line, deden
                         relevant_lines_str += line[1:] + '\n'
             else:
                 relevant_file_lines = file.head_file.splitlines()
-                relevant_lines_str = "\n".join(relevant_file_lines[start_line - 1:end_line])
+                relevant_lines_str = "\n".join(
+                    relevant_file_lines[start_line - 1 : end_line]
+                )
             if dedent and relevant_lines_str:
                 # Remove the longest leading string of spaces and tabs common to all lines.
                 relevant_lines_str = textwrap.dedent(relevant_lines_str)
-            relevant_lines_str = f"```{file.language}\n{relevant_lines_str}\n```"
+            relevant_lines_str = (
+                f"```{file.language}\n{relevant_lines_str}\n```"
+            )
             break
     return relevant_lines_str
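Editor's note: a minimal sketch of the slice, dedent, and fence-wrapping steps above, assuming a head_file string and a 1-based line range (all values invented):

import textwrap

head_file = "def f():\n    if True:\n        return 1\n"
start_line, end_line, language = 2, 3, "python"
relevant_file_lines = head_file.splitlines()
relevant_lines_str = "\n".join(relevant_file_lines[start_line - 1 : end_line])
# remove the longest leading whitespace run common to all lines
relevant_lines_str = textwrap.dedent(relevant_lines_str)
relevant_lines_str = f"```{language}\n{relevant_lines_str}\n```"
print(relevant_lines_str)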
@@ -325,14 +369,21 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
         ticket_url = ticket_analysis.get('ticket_url', '').strip()
         explanation = ''
         ticket_compliance_level = ''  # Individual ticket compliance
-        fully_compliant_str = ticket_analysis.get('fully_compliant_requirements', '').strip()
-        not_compliant_str = ticket_analysis.get('not_compliant_requirements', '').strip()
-        requires_further_human_verification = ticket_analysis.get('requires_further_human_verification',
-                                                                  '').strip()
+        fully_compliant_str = ticket_analysis.get(
+            'fully_compliant_requirements', ''
+        ).strip()
+        not_compliant_str = ticket_analysis.get(
+            'not_compliant_requirements', ''
+        ).strip()
+        requires_further_human_verification = ticket_analysis.get(
+            'requires_further_human_verification', ''
+        ).strip()

         if not fully_compliant_str and not not_compliant_str:
-            get_logger().debug(f"Ticket compliance has no requirements",
-                               artifact={'ticket_url': ticket_url})
+            get_logger().debug(
+                f"Ticket compliance has no requirements",
+                artifact={'ticket_url': ticket_url},
+            )
             continue

         # Calculate individual ticket compliance level
@@ -353,19 +404,27 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
         # build compliance string
         if fully_compliant_str:
-            explanation += f"Compliant requirements:\n\n{fully_compliant_str}\n\n"
+            explanation += (
+                f"Compliant requirements:\n\n{fully_compliant_str}\n\n"
+            )
         if not_compliant_str:
-            explanation += f"Non-compliant requirements:\n\n{not_compliant_str}\n\n"
+            explanation += (
+                f"Non-compliant requirements:\n\n{not_compliant_str}\n\n"
+            )
         if requires_further_human_verification:
             explanation += f"Requires further human verification:\n\n{requires_further_human_verification}\n\n"
         ticket_compliance_str += f"\n\n**[{ticket_url.split('/')[-1]}]({ticket_url}) - {ticket_compliance_level}**\n\n{explanation}\n\n"

         # for debugging
         if requires_further_human_verification:
-            get_logger().debug(f"Ticket compliance requires further human verification",
-                               artifact={'ticket_url': ticket_url,
-                                         'requires_further_human_verification': requires_further_human_verification,
-                                         'compliance_level': ticket_compliance_level})
+            get_logger().debug(
+                f"Ticket compliance requires further human verification",
+                artifact={
+                    'ticket_url': ticket_url,
+                    'requires_further_human_verification': requires_further_human_verification,
+                    'compliance_level': ticket_compliance_level,
+                },
+            )

     except Exception as e:
         get_logger().exception(f"Failed to process ticket compliance: {e}")
@@ -381,7 +440,10 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
         compliance_emoji = ''
     elif any(level == 'Not compliant' for level in all_compliance_levels):
         # If there's a mix of compliant and non-compliant tickets
-        if any(level in ['Fully compliant', 'PR Code Verified'] for level in all_compliance_levels):
+        if any(
+            level in ['Fully compliant', 'PR Code Verified']
+            for level in all_compliance_levels
+        ):
             compliance_level = 'Partially compliant'
             compliance_emoji = '🔶'
         else:
@@ -395,7 +457,9 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
             compliance_emoji = ''

     # Set extra statistics outside the ticket loop
-    get_settings().set('config.extra_statistics', {'compliance_level': compliance_level})
+    get_settings().set(
+        'config.extra_statistics', {'compliance_level': compliance_level}
+    )

     # editing table row for ticket compliance analysis
     if gfm_supported:
@@ -425,7 +489,9 @@ def process_can_be_split(emoji, value):
         for i, split in enumerate(value):
             title = split.get('title', '')
             relevant_files = split.get('relevant_files', [])
-            markdown_text += f"<details><summary>\n子 PR 主题: <b>{title}</b></summary>\n\n"
+            markdown_text += (
+                f"<details><summary>\n子 PR 主题: <b>{title}</b></summary>\n\n"
+            )
             markdown_text += f"___\n\n相关文件:\n\n"
             for file in relevant_files:
                 markdown_text += f"- {file}\n"
@@ -464,7 +530,9 @@ def process_can_be_split(emoji, value):
     return markdown_text

-def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool = True) -> str:
+def parse_code_suggestion(
+    code_suggestion: dict, i: int = 0, gfm_supported: bool = True
+) -> str:
     """
     Convert a dictionary of data into markdown format.
@@ -484,15 +552,19 @@ def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool
                     markdown_text += f"<tr><td>相关文件</td><td>{relevant_file}</td></tr>"
                     # continue
                 elif sub_key.lower() == 'suggestion':
-                    markdown_text += (f"<tr><td>{sub_key} &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</td>"
-                                      f"<td>\n\n<strong>\n\n{sub_value.strip()}\n\n</strong>\n</td></tr>")
+                    markdown_text += (
+                        f"<tr><td>{sub_key} &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</td>"
+                        f"<td>\n\n<strong>\n\n{sub_value.strip()}\n\n</strong>\n</td></tr>"
+                    )
                 elif sub_key.lower() == 'relevant_line':
                     markdown_text += f"<tr><td>相关行</td>"
                     sub_value_list = sub_value.split('](')
                     relevant_line = sub_value_list[0].lstrip('`').lstrip('[')
                     if len(sub_value_list) > 1:
                         link = sub_value_list[1].rstrip(')').strip('`')
-                        markdown_text += f"<td><a href='{link}'>{relevant_line}</a></td>"
+                        markdown_text += (
+                            f"<td><a href='{link}'>{relevant_line}</a></td>"
+                        )
                     else:
                         markdown_text += f"<td>{relevant_line}</td>"
                     markdown_text += "</tr>"
@@ -509,7 +581,10 @@ def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool
                 sub_value = sub_value.rstrip()
             if isinstance(sub_value, dict):  # "code example"
                 markdown_text += f"  - **{sub_key}:**\n"
-                for code_key, code_value in sub_value.items():  # 'before' and 'after' code
+                for (
+                    code_key,
+                    code_value,
+                ) in sub_value.items():  # 'before' and 'after' code
                     code_str = f"```\n{code_value}\n```"
                     code_str_indented = textwrap.indent(code_str, '        ')
                     markdown_text += f"  - **{code_key}:**\n{code_str_indented}\n"
@@ -520,7 +595,9 @@ def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool
                 markdown_text += f"   **{sub_key}:** {sub_value}      \n"
                 if "relevant_line" not in sub_key.lower():  # nicer presentation
                     # markdown_text = markdown_text.rstrip('\n') + "\\\n"  # works for gitlab
-                    markdown_text = markdown_text.rstrip('\n') + "   \n"  # works for gitlab and bitbucker
+                    markdown_text = (
+                        markdown_text.rstrip('\n') + "   \n"
+                    )  # works for gitlab and bitbucker

     markdown_text += "\n"
     return markdown_text
@@ -561,9 +638,15 @@ def try_fix_json(review, max_iter=10, code_suggestions=False):
     else:
         closing_bracket = "]}}"

-    if (review.rfind("'Code feedback': [") > 0 or review.rfind('"Code feedback": [') > 0) or \
-            (review.rfind("'Code suggestions': [") > 0 or review.rfind('"Code suggestions": [') > 0) :
-        last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1
+    if (
+        review.rfind("'Code feedback': [") > 0 or review.rfind('"Code feedback": [') > 0
+    ) or (
+        review.rfind("'Code suggestions': [") > 0
+        or review.rfind('"Code suggestions": [') > 0
+    ):
+        last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][
+            -1
+        ] - 1
         valid_json = False
         iter_count = 0
@@ -574,7 +657,9 @@ def try_fix_json(review, max_iter=10, code_suggestions=False):
                 review = review[:last_code_suggestion_ind].strip() + closing_bracket
             except json.decoder.JSONDecodeError:
                 review = review[:last_code_suggestion_ind]
-                last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1
+                last_code_suggestion_ind = [
+                    m.end() for m in re.finditer(r"\}\s*,", review)
+                ][-1] - 1
                 iter_count += 1

         if not valid_json:
@@ -629,7 +714,12 @@ def convert_str_to_datetime(date_str):
     return datetime.strptime(date_str, datetime_format)

-def load_large_diff(filename, new_file_content_str: str, original_file_content_str: str, show_warning: bool = True) -> str:
+def load_large_diff(
+    filename,
+    new_file_content_str: str,
+    original_file_content_str: str,
+    show_warning: bool = True,
+) -> str:
     """
     Generate a patch for a modified file by comparing the original content of the file with the new content provided as
     input.
@@ -640,10 +730,14 @@ def load_large_diff(filename, new_file_content_str: str, original_file_content_s
     try:
         original_file_content_str = (original_file_content_str or "").rstrip() + "\n"
         new_file_content_str = (new_file_content_str or "").rstrip() + "\n"
-        diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
-                                    new_file_content_str.splitlines(keepends=True))
+        diff = difflib.unified_diff(
+            original_file_content_str.splitlines(keepends=True),
+            new_file_content_str.splitlines(keepends=True),
+        )
         if get_settings().config.verbosity_level >= 2 and show_warning:
-            get_logger().info(f"File was modified, but no patch was found. Manually creating patch: {filename}.")
+            get_logger().info(
+                f"File was modified, but no patch was found. Manually creating patch: {filename}."
+            )
         patch = ''.join(diff)
         return patch
     except Exception as e:
@@ -693,42 +787,68 @@ def _fix_key_value(key: str, value: str):
     try:
         value = yaml.safe_load(value)
     except Exception as e:
-        get_logger().debug(f"Failed to parse YAML for config override {key}={value}", exc_info=e)
+        get_logger().debug(
+            f"Failed to parse YAML for config override {key}={value}", exc_info=e
+        )
     return key, value

-def load_yaml(response_text: str, keys_fix_yaml: List[str] = [], first_key="", last_key="") -> dict:
-    response_text = response_text.strip('\n').removeprefix('```yaml').rstrip().removesuffix('```')
+def load_yaml(
+    response_text: str, keys_fix_yaml: List[str] = [], first_key="", last_key=""
+) -> dict:
+    response_text = (
+        response_text.strip('\n').removeprefix('```yaml').rstrip().removesuffix('```')
+    )
     try:
         data = yaml.safe_load(response_text)
     except Exception as e:
         get_logger().warning(f"Initial failure to parse AI prediction: {e}")
-        data = try_fix_yaml(response_text, keys_fix_yaml=keys_fix_yaml, first_key=first_key, last_key=last_key)
+        data = try_fix_yaml(
+            response_text,
+            keys_fix_yaml=keys_fix_yaml,
+            first_key=first_key,
+            last_key=last_key,
+        )
         if not data:
-            get_logger().error(f"Failed to parse AI prediction after fallbacks",
-                               artifact={'response_text': response_text})
+            get_logger().error(
+                f"Failed to parse AI prediction after fallbacks",
+                artifact={'response_text': response_text},
+            )
         else:
-            get_logger().info(f"Successfully parsed AI prediction after fallbacks",
-                              artifact={'response_text': response_text})
+            get_logger().info(
+                f"Successfully parsed AI prediction after fallbacks",
+                artifact={'response_text': response_text},
+            )
     return data
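Editor's note: the fence-stripping chain at the top of load_yaml reduces to a few string operations. A self-contained illustration (requires PyYAML, which this codebase already uses; the response text is invented):

import yaml

response_text = "```yaml\nkey: value\nitems:\n- a\n- b\n```"
cleaned = response_text.strip('\n').removeprefix('```yaml').rstrip().removesuffix('```')
print(yaml.safe_load(cleaned))  # -> {'key': 'value', 'items': ['a', 'b']}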
-def try_fix_yaml(response_text: str,
-                 keys_fix_yaml: List[str] = [],
-                 first_key="",
-                 last_key="",) -> dict:
+def try_fix_yaml(
+    response_text: str,
+    keys_fix_yaml: List[str] = [],
+    first_key="",
+    last_key="",
+) -> dict:
     response_text_lines = response_text.split('\n')

-    keys_yaml = ['relevant line:', 'suggestion content:', 'relevant file:', 'existing code:', 'improved code:']
+    keys_yaml = [
+        'relevant line:',
+        'suggestion content:',
+        'relevant file:',
+        'existing code:',
+        'improved code:',
+    ]
     keys_yaml = keys_yaml + keys_fix_yaml
     # first fallback - try to convert 'relevant line: ...' to relevant line: |-\n        ...'
     response_text_lines_copy = response_text_lines.copy()
     for i in range(0, len(response_text_lines_copy)):
         for key in keys_yaml:
-            if key in response_text_lines_copy[i] and not '|' in response_text_lines_copy[i]:
-                response_text_lines_copy[i] = response_text_lines_copy[i].replace(f'{key}',
-                                                                                  f'{key} |\n        ')
+            if (
+                key in response_text_lines_copy[i]
+                and not '|' in response_text_lines_copy[i]
+            ):
+                response_text_lines_copy[i] = response_text_lines_copy[i].replace(
+                    f'{key}', f'{key} |\n        '
+                )
     try:
         data = yaml.safe_load('\n'.join(response_text_lines_copy))
         get_logger().info(f"Successfully parsed AI prediction after adding |-\n")
@ -743,22 +863,26 @@ def try_fix_yaml(response_text: str,
snippet_text = snippet.group() snippet_text = snippet.group()
try: try:
data = yaml.safe_load(snippet_text.removeprefix('```yaml').rstrip('`')) data = yaml.safe_load(snippet_text.removeprefix('```yaml').rstrip('`'))
get_logger().info(f"Successfully parsed AI prediction after extracting yaml snippet") get_logger().info(
f"Successfully parsed AI prediction after extracting yaml snippet"
)
return data return data
except: except:
pass pass
# third fallback - try to remove leading and trailing curly brackets # third fallback - try to remove leading and trailing curly brackets
response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n') response_text_copy = (
response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n')
)
try: try:
data = yaml.safe_load(response_text_copy) data = yaml.safe_load(response_text_copy)
get_logger().info(f"Successfully parsed AI prediction after removing curly brackets") get_logger().info(
f"Successfully parsed AI prediction after removing curly brackets"
)
return data return data
except: except:
pass pass
# forth fallback - try to extract yaml snippet by 'first_key' and 'last_key' # forth fallback - try to extract yaml snippet by 'first_key' and 'last_key'
# note that 'last_key' can be in practice a key that is not the last key in the yaml snippet. # note that 'last_key' can be in practice a key that is not the last key in the yaml snippet.
# it just needs to be some inner key, so we can look for newlines after it # it just needs to be some inner key, so we can look for newlines after it
@ -767,13 +891,23 @@ def try_fix_yaml(response_text: str,
if index_start == -1: if index_start == -1:
index_start = response_text.find(f"{first_key}:") index_start = response_text.find(f"{first_key}:")
index_last_code = response_text.rfind(f"{last_key}:") index_last_code = response_text.rfind(f"{last_key}:")
index_end = response_text.find("\n\n", index_last_code) # look for newlines after last_key index_end = response_text.find(
"\n\n", index_last_code
) # look for newlines after last_key
if index_end == -1: if index_end == -1:
index_end = len(response_text) index_end = len(response_text)
response_text_copy = response_text[index_start:index_end].strip().strip('```yaml').strip('`').strip() response_text_copy = (
response_text[index_start:index_end]
.strip()
.strip('```yaml')
.strip('`')
.strip()
)
try: try:
data = yaml.safe_load(response_text_copy) data = yaml.safe_load(response_text_copy)
get_logger().info(f"Successfully parsed AI prediction after extracting yaml snippet") get_logger().info(
f"Successfully parsed AI prediction after extracting yaml snippet"
)
return data return data
except: except:
pass pass
@ -784,7 +918,9 @@ def try_fix_yaml(response_text: str,
response_text_lines_copy[i] = ' ' + response_text_lines_copy[i][1:] response_text_lines_copy[i] = ' ' + response_text_lines_copy[i][1:]
try: try:
data = yaml.safe_load('\n'.join(response_text_lines_copy)) data = yaml.safe_load('\n'.join(response_text_lines_copy))
get_logger().info(f"Successfully parsed AI prediction after removing leading '+'") get_logger().info(
f"Successfully parsed AI prediction after removing leading '+'"
)
return data return data
except: except:
pass pass
@ -794,7 +930,9 @@ def try_fix_yaml(response_text: str,
response_text_lines_tmp = '\n'.join(response_text_lines[:-i]) response_text_lines_tmp = '\n'.join(response_text_lines[:-i])
try: try:
data = yaml.safe_load(response_text_lines_tmp) data = yaml.safe_load(response_text_lines_tmp)
get_logger().info(f"Successfully parsed AI prediction after removing {i} lines") get_logger().info(
f"Successfully parsed AI prediction after removing {i} lines"
)
return data return data
except: except:
pass pass
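Editor's note: all of the try_fix_yaml fallbacks above share one shape: apply an increasingly aggressive textual repair, then retry yaml.safe_load. A condensed sketch of that pattern, with simplified stand-in repairs rather than the exact ones used here:

```python
import yaml

def parse_with_fallbacks(text: str):
    # Try each repair in order; return the first result that parses to
    # something truthy. The repairs below are illustrative stand-ins.
    repairs = [
        lambda t: t,                                              # 0: as-is
        lambda t: t.strip().removeprefix('{').removesuffix('}'),  # drop braces
        lambda t: '\n'.join(t.splitlines()[:-1]),                 # drop last line
    ]
    for repair in repairs:
        try:
            data = yaml.safe_load(repair(text))
            if data:
                return data
        except yaml.YAMLError:
            continue
    return None
```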
@@ -820,11 +958,14 @@ def set_custom_labels(variables, git_provider=None):
     for k, v in labels.items():
         description = "'" + v['description'].strip('\n').replace('\n', '\\n') + "'"
         # variables["custom_labels_class"] += f"\n    {k.lower().replace(' ', '_')} = '{k}' # {description}"
-        variables["custom_labels_class"] += f"\n    {k.lower().replace(' ', '_')} = {description}"
+        variables[
+            "custom_labels_class"
+        ] += f"\n    {k.lower().replace(' ', '_')} = {description}"
         labels_minimal_to_labels_dict[k.lower().replace(' ', '_')] = k
         counter += 1
     variables["labels_minimal_to_labels_dict"] = labels_minimal_to_labels_dict


 def get_user_labels(current_labels: List[str] = None):
     """
     Only keep labels that has been added by the user
@@ -866,14 +1007,22 @@ def get_max_tokens(model):
     elif settings.config.custom_model_max_tokens > 0:
         max_tokens_model = settings.config.custom_model_max_tokens
     else:
-        raise Exception(f"Ensure {model} is defined in MAX_TOKENS in ./pr_agent/algo/__init__.py or set a positive value for it in config.custom_model_max_tokens")
+        raise Exception(
+            f"Ensure {model} is defined in MAX_TOKENS in ./pr_agent/algo/__init__.py or set a positive value for it in config.custom_model_max_tokens"
+        )

     if settings.config.max_model_tokens and settings.config.max_model_tokens > 0:
         max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
     return max_tokens_model


-def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_tokens=None, delete_last_line=False) -> str:
+def clip_tokens(
+    text: str,
+    max_tokens: int,
+    add_three_dots=True,
+    num_input_tokens=None,
+    delete_last_line=False,
+) -> str:
     """
     Clip the number of tokens in a string to a maximum number of tokens.
@@ -917,6 +1066,7 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_token
         get_logger().warning(f"Failed to clip tokens: {e}")
         return text


 def replace_code_tags(text):
     """
     Replace odd instances of ` with <code> and even instances of ` with </code>
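Editor's note: the docstring above describes alternating backtick replacement. A sketch consistent with it (this is my reading of the docstring, not a copy of the function body):

```python
def replace_code_tags(text: str) -> str:
    # Splitting on backticks leaves the spans that were between backticks
    # at the odd indices, so only those get wrapped in <code> tags.
    parts = text.split('`')
    for i in range(1, len(parts), 2):
        parts[i] = '<code>' + parts[i] + '</code>'
    return ''.join(parts)

assert replace_code_tags("use `foo` here") == "use <code>foo</code> here"
```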
@@ -928,15 +1078,16 @@ def replace_code_tags(text):
     return ''.join(parts)


-def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
-                                              relevant_file: str,
-                                              relevant_line_in_file: str,
-                                              absolute_position: int = None) -> Tuple[int, int]:
+def find_line_number_of_relevant_line_in_file(
+    diff_files: List[FilePatchInfo],
+    relevant_file: str,
+    relevant_line_in_file: str,
+    absolute_position: int = None,
+) -> Tuple[int, int]:
     position = -1
     if absolute_position is None:
         absolute_position = -1
-    re_hunk_header = re.compile(
-        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
+    re_hunk_header = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

     if not diff_files:
         return position, absolute_position
@@ -965,12 +1116,12 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
                     break
             else:
                 # try to find the line in the patch using difflib, with some margin of error
-                matches_difflib: list[str | Any] = difflib.get_close_matches(relevant_line_in_file,
-                                                                             patch_lines, n=3, cutoff=0.93)
+                matches_difflib: list[str | Any] = difflib.get_close_matches(
+                    relevant_line_in_file, patch_lines, n=3, cutoff=0.93
+                )
                 if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'):
                     relevant_line_in_file = matches_difflib[0]
                 for i, line in enumerate(patch_lines):
                     if line.startswith('@@'):
                         delta = 0
@@ -1002,19 +1153,26 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
                 break
     return position, absolute_position


 def get_rate_limit_status(github_token) -> dict:
-    GITHUB_API_URL = get_settings(use_context=False).get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/")  # "https://api.github.com"
+    GITHUB_API_URL = (
+        get_settings(use_context=False)
+        .get("GITHUB.BASE_URL", "https://api.github.com")
+        .rstrip("/")
+    )  # "https://api.github.com"
     # GITHUB_API_URL = "https://api.github.com"
     RATE_LIMIT_URL = f"{GITHUB_API_URL}/rate_limit"
     HEADERS = {
         "Accept": "application/vnd.github.v3+json",
-        "Authorization": f"token {github_token}"
+        "Authorization": f"token {github_token}",
     }

     response = requests.get(RATE_LIMIT_URL, headers=HEADERS)
     try:
         rate_limit_info = response.json()
-        if rate_limit_info.get('message') == 'Rate limiting is not enabled.':  # for github enterprise
+        if (
+            rate_limit_info.get('message') == 'Rate limiting is not enabled.'
+        ):  # for github enterprise
             return {'resources': {}}
         response.raise_for_status()  # Check for HTTP errors
     except:  # retry
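Editor's note: a minimal standalone version of the rate-limit query above, assuming a personal access token (the helper name is mine; per GitHub's documentation, calls to the rate_limit endpoint do not count against the core quota):

```python
import requests

def remaining_core_calls(token: str, base_url: str = "https://api.github.com") -> int:
    # Query GET /rate_limit and pull out the remaining "core" API calls.
    resp = requests.get(
        f"{base_url}/rate_limit",
        headers={
            "Accept": "application/vnd.github.v3+json",
            "Authorization": f"token {token}",
        },
        timeout=10,
    )
    resp.raise_for_status()
    return resp.json()["resources"]["core"]["remaining"]
```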
@@ -1024,11 +1182,15 @@ def get_rate_limit_status(github_token) -> dict:
     return rate_limit_info


-def validate_rate_limit_github(github_token, installation_id=None, threshold=0.1) -> bool:
+def validate_rate_limit_github(
+    github_token, installation_id=None, threshold=0.1
+) -> bool:
     try:
         rate_limit_status = get_rate_limit_status(github_token)
         if installation_id:
-            get_logger().debug(f"installation_id: {installation_id}, Rate limit status: {rate_limit_status['rate']}")
+            get_logger().debug(
+                f"installation_id: {installation_id}, Rate limit status: {rate_limit_status['rate']}"
+            )
         # validate that the rate limit is not exceeded
         for key, value in rate_limit_status['resources'].items():
@@ -1037,8 +1199,9 @@ def validate_rate_limit_github(github_token, installation_id=None, threshold=0.1
                 return False
         return True
     except Exception as e:
-        get_logger().error(f"Error in rate limit {e}",
-                           artifact={"traceback": traceback.format_exc()})
+        get_logger().error(
+            f"Error in rate limit {e}", artifact={"traceback": traceback.format_exc()}
+        )
         return True
@@ -1051,7 +1214,9 @@ def validate_and_await_rate_limit(github_token):
             get_logger().error(f"key: {key}, value: {value}")
             sleep_time_sec = value['reset'] - datetime.now().timestamp()
             sleep_time_hour = sleep_time_sec / 3600.0
-            get_logger().error(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours")
+            get_logger().error(
+                f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours"
+            )
             if sleep_time_sec > 0:
                 time.sleep(sleep_time_sec + 1)
             rate_limit_status = get_rate_limit_status(github_token)
@@ -1068,21 +1233,38 @@ def github_action_output(output_data: dict, key_name: str):
         key_data = output_data.get(key_name, {})
         with open(os.environ['GITHUB_OUTPUT'], 'a') as fh:
-            print(f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}", file=fh)
+            print(
+                f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}",
+                file=fh,
+            )
     except Exception as e:
         get_logger().error(f"Failed to write to GitHub Action output: {e}")
     return
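Editor's note: the GITHUB_OUTPUT write above depends on the JSON staying on one line. A self-contained sketch of the same idea (the helper name is mine):

```python
import json
import os

def write_github_output(key: str, value: dict) -> None:
    # GITHUB_OUTPUT points to a file that GitHub Actions parses as
    # key=value lines, so the JSON must be serialized onto a single line.
    output_path = os.environ.get("GITHUB_OUTPUT")
    if not output_path:
        return  # not running inside a GitHub Action
    with open(output_path, "a") as fh:
        print(f"{key}={json.dumps(value, indent=None, ensure_ascii=False)}", file=fh)
```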
 def show_relevant_configurations(relevant_section: str) -> str:
-    skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys", "app_id", "redirect",
-                 'trial_prefix_message', 'no_eligible_message', 'identity_provider', 'ALLOWED_REPOS','APP_NAME']
+    skip_keys = [
+        'ai_disclaimer',
+        'ai_disclaimer_title',
+        'ANALYTICS_FOLDER',
+        'secret_provider',
+        "skip_keys",
+        "app_id",
+        "redirect",
+        'trial_prefix_message',
+        'no_eligible_message',
+        'identity_provider',
+        'ALLOWED_REPOS',
+        'APP_NAME',
+    ]
     extra_skip_keys = get_settings().config.get('config.skip_keys', [])
     if extra_skip_keys:
         skip_keys.extend(extra_skip_keys)

     markdown_text = ""
-    markdown_text += "\n<hr>\n<details> <summary><strong>🛠️ 相关配置:</strong></summary> \n\n"
+    markdown_text += (
+        "\n<hr>\n<details> <summary><strong>🛠️ 相关配置:</strong></summary> \n\n"
+    )
     markdown_text += "<br>以下是相关工具地配置 [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml):\n\n"
     markdown_text += f"**[config**]\n```yaml\n\n"
     for key, value in get_settings().config.items():
@@ -1099,6 +1281,7 @@ def show_relevant_configurations(relevant_section: str) -> str:
     markdown_text += "\n</details>\n"
     return markdown_text


 def is_value_no(value):
     if not value:
         return True
@@ -1131,7 +1314,9 @@ def process_description(description_full: str) -> Tuple[str, List]:
     if not description_full:
         return "", []

-    description_split = description_full.split(PRDescriptionHeader.CHANGES_WALKTHROUGH.value)
+    description_split = description_full.split(
+        PRDescriptionHeader.CHANGES_WALKTHROUGH.value
+    )
     base_description_str = description_split[0]
     changes_walkthrough_str = ""
     files = []
@@ -1174,38 +1359,51 @@ def process_description(description_full: str) -> Tuple[str, List]:
                     short_summary = res.group(2).strip()
                     long_filename = res.group(3).strip()
                     long_summary = res.group(4).strip()
-                    long_summary = long_summary.replace('<br> *', '\n*').replace('<br>','').replace('\n','<br>')
+                    long_summary = (
+                        long_summary.replace('<br> *', '\n*')
+                        .replace('<br>', '')
+                        .replace('\n', '<br>')
+                    )
                     long_summary = h.handle(long_summary).strip()
                     if long_summary.startswith('\\-'):
                         long_summary = "* " + long_summary[2:]
                     elif not long_summary.startswith('*'):
                         long_summary = f"* {long_summary}"

-                    files.append({
-                        'short_file_name': short_filename,
-                        'full_file_name': long_filename,
-                        'short_summary': short_summary,
-                        'long_summary': long_summary
-                    })
+                    files.append(
+                        {
+                            'short_file_name': short_filename,
+                            'full_file_name': long_filename,
+                            'short_summary': short_summary,
+                            'long_summary': long_summary,
+                        }
+                    )
                 else:
                     if '<code>...</code>' in file_data:
                         pass  # PR with many files. some did not get analyzed
                     else:
-                        get_logger().error(f"Failed to parse description", artifact={'description': file_data})
+                        get_logger().error(
+                            f"Failed to parse description",
+                            artifact={'description': file_data},
+                        )
             except Exception as e:
-                get_logger().exception(f"Failed to process description: {e}", artifact={'description': file_data})
+                get_logger().exception(
+                    f"Failed to process description: {e}",
+                    artifact={'description': file_data},
+                )
     except Exception as e:
         get_logger().exception(f"Failed to process description: {e}")

     return base_description_str, files


 def get_version() -> str:
     # First check pyproject.toml if running directly out of repository
     if os.path.exists("pyproject.toml"):
         if sys.version_info >= (3, 11):
             import tomllib
             with open("pyproject.toml", "rb") as f:
                 data = tomllib.load(f)
                 if "project" in data and "version" in data["project"]:
@@ -1213,7 +1411,9 @@ def get_version() -> str:
         else:
             get_logger().warning("Version not found in pyproject.toml")
     else:
-        get_logger().warning("Unable to determine local version from pyproject.toml")
+        get_logger().warning(
+            "Unable to determine local version from pyproject.toml"
+        )

     # Otherwise get the installed pip package version
     try:
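Editor's note: get_version() prefers the local pyproject.toml when running from a checkout. A sketch of that check, assuming a standard [project] table (the helper name is mine):

```python
import sys

def local_version(pyproject_path: str = "pyproject.toml") -> str | None:
    # tomllib ships with the standard library only from Python 3.11 on,
    # which is why the code above gates the import on sys.version_info.
    if sys.version_info < (3, 11):
        return None
    import tomllib
    with open(pyproject_path, "rb") as f:
        data = tomllib.load(f)
    return data.get("project", {}).get("version")
```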

View File

@@ -12,8 +12,9 @@ setup_logger(log_level)
 def set_parser():
-    parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage=
-                                     """\
+    parser = argparse.ArgumentParser(
+        description='AI based pull request analyzer',
+        usage="""\
 Usage: cli.py --pr-url=<URL on supported git hosting service> <command> [<args>].
 For example:
 - cli.py --pr_url=... review
@@ -45,11 +46,20 @@ def set_parser():
 Configuration:
 To edit any configuration parameter from 'configuration.toml', just add -config_path=<value>.
 For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
-""")
-    parser.add_argument('--version', action='version', version=f'pr-agent {get_version()}')
-    parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', default=None)
-    parser.add_argument('--issue_url', type=str, help='The URL of the Issue to review', default=None)
-    parser.add_argument('command', type=str, help='The', choices=commands, default='review')
+""",
+    )
+    parser.add_argument(
+        '--version', action='version', version=f'pr-agent {get_version()}'
+    )
+    parser.add_argument(
+        '--pr_url', type=str, help='The URL of the PR to review', default=None
+    )
+    parser.add_argument(
+        '--issue_url', type=str, help='The URL of the Issue to review', default=None
+    )
+    parser.add_argument(
+        'command', type=str, help='The', choices=commands, default='review'
+    )
     parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
     return parser
@@ -76,14 +86,24 @@ def run(inargs=None, args=None):
     async def inner():
         if args.issue_url:
-            result = await asyncio.create_task(PRAgent().handle_request(args.issue_url, [command] + args.rest))
+            result = await asyncio.create_task(
+                PRAgent().handle_request(args.issue_url, [command] + args.rest)
+            )
         else:
-            result = await asyncio.create_task(PRAgent().handle_request(args.pr_url, [command] + args.rest))
+            result = await asyncio.create_task(
+                PRAgent().handle_request(args.pr_url, [command] + args.rest)
+            )

         if get_settings().litellm.get("enable_callbacks", False):
             # There may be additional events on the event queue from the run above. If there are give them time to complete.
             get_logger().debug("Waiting for event queue to complete")
-            await asyncio.wait([task for task in asyncio.all_tasks() if task is not asyncio.current_task()])
+            await asyncio.wait(
+                [
+                    task
+                    for task in asyncio.all_tasks()
+                    if task is not asyncio.current_task()
+                ]
+            )

         return result
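Editor's note: the asyncio.wait call above drains every task except the one doing the waiting. A standalone sketch of that drain, with the exclusion made explicit (the helper name is mine):

```python
import asyncio

async def drain_pending_tasks() -> None:
    # Wait for every task except the current one; without the exclusion,
    # asyncio.wait would wait on the task that is itself doing the waiting.
    pending = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    if pending:  # asyncio.wait raises ValueError on an empty set
        await asyncio.wait(pending)
```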

View File

@@ -7,7 +7,9 @@ def main():
     provider = "github"  # GitHub provider
     user_token = "..."  # GitHub user token
     openai_key = "..."  # OpenAI key
-    pr_url = "..."  # PR URL, for example 'https://github.com/Codium-ai/pr-agent/pull/809'
+    pr_url = (
+        "..."  # PR URL, for example 'https://github.com/Codium-ai/pr-agent/pull/809'
+    )
     command = "/review"  # Command to run (e.g. '/review', '/describe', '/ask="What is the purpose of this PR?"')

     # Setting the configurations
# Setting the configurations # Setting the configurations

View File

@@ -11,7 +11,9 @@ current_dir = dirname(abspath(__file__))
 global_settings = Dynaconf(
     envvar_prefix=False,
     merge_enabled=True,
-    settings_files=[join(current_dir, f) for f in [
+    settings_files=[
+        join(current_dir, f)
+        for f in [
         "settings/configuration.toml",
         "settings/ignore.toml",
         "settings/language_extensions.toml",
@@ -30,7 +32,8 @@ global_settings = Dynaconf(
         "settings/pr_help_prompts.toml",
         "settings/.secrets.toml",
         "settings_prod/.secrets.toml",
-    ]]
+        ]
+    ],
 )
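Editor's note: a minimal sketch of the same Dynaconf pattern, one settings object merging several TOML files with later files overriding earlier ones (the file names here are illustrative):

```python
from dynaconf import Dynaconf

settings = Dynaconf(
    envvar_prefix=False,
    merge_enabled=True,  # later files merge into, rather than replace, earlier ones
    settings_files=["defaults.toml", "overrides.toml"],
)
print(settings.get("config.git_provider", "github"))
```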

View File

@@ -3,8 +3,9 @@ from starlette_context import context
 from utils.pr_agent.config_loader import get_settings
 from utils.pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider
 from utils.pr_agent.git_providers.bitbucket_provider import BitbucketProvider
-from utils.pr_agent.git_providers.bitbucket_server_provider import \
-    BitbucketServerProvider
+from utils.pr_agent.git_providers.bitbucket_server_provider import (
+    BitbucketServerProvider,
+)
 from utils.pr_agent.git_providers.codecommit_provider import CodeCommitProvider
 from utils.pr_agent.git_providers.gerrit_provider import GerritProvider
 from utils.pr_agent.git_providers.git_provider import GitProvider
@@ -28,7 +29,9 @@ def get_git_provider():
     try:
         provider_id = get_settings().config.git_provider
     except AttributeError as e:
-        raise ValueError("git_provider is a required attribute in the configuration file") from e
+        raise ValueError(
+            "git_provider is a required attribute in the configuration file"
+        ) from e
     if provider_id not in _GIT_PROVIDERS:
         raise ValueError(f"Unknown git provider: {provider_id}")
     return _GIT_PROVIDERS[provider_id]
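Editor's note: the lookup above follows a plain registry pattern (a name-to-class dict plus a resolving function). A generic sketch of that pattern, independent of this module's actual provider classes:

```python
from typing import Callable, Dict, Type

_REGISTRY: Dict[str, Type] = {}

def register(name: str) -> Callable[[Type], Type]:
    # Decorator that records a class in the registry under a string key.
    def wrap(cls: Type) -> Type:
        _REGISTRY[name] = cls
        return cls
    return wrap

def resolve(name: str) -> Type:
    if name not in _REGISTRY:
        raise ValueError(f"Unknown git provider: {name}")
    return _REGISTRY[name]
```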

View File

@@ -6,8 +6,11 @@ from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
 from ..algo.file_filter import filter_ignored
 from ..algo.language_handler import is_valid_file
-from ..algo.utils import (PRDescriptionHeader, find_line_number_of_relevant_line_in_file,
-                          load_large_diff)
+from ..algo.utils import (
+    PRDescriptionHeader,
+    find_line_number_of_relevant_line_in_file,
+    load_large_diff,
+)
 from ..config_loader import get_settings
 from ..log import get_logger
 from .git_provider import GitProvider
@@ -20,11 +23,16 @@ try:
     # noinspection PyUnresolvedReferences
     # noinspection PyUnresolvedReferences
     from azure.devops.connection import Connection

     # noinspection PyUnresolvedReferences
-    from azure.devops.v7_1.git.models import (Comment, CommentThread,
-                                              GitPullRequest,
-                                              GitPullRequestIterationChanges,
-                                              GitVersionDescriptor)
+    from azure.devops.v7_1.git.models import (
+        Comment,
+        CommentThread,
+        GitPullRequest,
+        GitPullRequestIterationChanges,
+        GitVersionDescriptor,
+    )

     # noinspection PyUnresolvedReferences
     from azure.identity import DefaultAzureCredential
     from msrest.authentication import BasicAuthentication
@@ -33,7 +41,6 @@ except ImportError:


 class AzureDevopsProvider(GitProvider):
     def __init__(
         self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
     ):
@@ -67,13 +74,16 @@ class AzureDevopsProvider(GitProvider):
             if not relevant_lines_start or relevant_lines_start == -1:
                 get_logger().warning(
-                    f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
+                    f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}"
+                )
                 continue

             if relevant_lines_end < relevant_lines_start:
-                get_logger().warning(f"Failed to publish code suggestion, "
-                                     f"relevant_lines_end is {relevant_lines_end} and "
-                                     f"relevant_lines_start is {relevant_lines_start}")
+                get_logger().warning(
+                    f"Failed to publish code suggestion, "
+                    f"relevant_lines_end is {relevant_lines_end} and "
+                    f"relevant_lines_start is {relevant_lines_start}"
+                )
                 continue

             if relevant_lines_end > relevant_lines_start:
@@ -98,7 +108,8 @@ class AzureDevopsProvider(GitProvider):
         for post_parameters in post_parameters_list:
             try:
                 comment = Comment(content=post_parameters["body"], comment_type=1)
-                thread = CommentThread(comments=[comment],
+                thread = CommentThread(
+                    comments=[comment],
                     thread_context={
                         "filePath": post_parameters["path"],
                         "rightFileStart": {
@@ -109,19 +120,20 @@ class AzureDevopsProvider(GitProvider):
                             "line": post_parameters["line"],
                             "offset": 1,
                         },
-                    })
+                    },
+                )
                 self.azure_devops_client.create_thread(
                     comment_thread=thread,
                     project=self.workspace_slug,
                     repository_id=self.repo_slug,
-                    pull_request_id=self.pr_num
+                    pull_request_id=self.pr_num,
                 )
             except Exception as e:
-                get_logger().warning(f"Azure failed to publish code suggestion, error: {e}")
+                get_logger().warning(
+                    f"Azure failed to publish code suggestion, error: {e}"
+                )
         return True

     def get_pr_description_full(self) -> str:
         return self.pr.description
@@ -220,7 +232,6 @@ class AzureDevopsProvider(GitProvider):
     def get_diff_files(self) -> list[FilePatchInfo]:
         try:
             if self.diff_files:
                 return self.diff_files
@@ -231,18 +242,20 @@ class AzureDevopsProvider(GitProvider):
             iterations = self.azure_devops_client.get_pull_request_iterations(
                 repository_id=self.repo_slug,
                 pull_request_id=self.pr_num,
-                project=self.workspace_slug
+                project=self.workspace_slug,
             )
             changes = None
             if iterations:
-                iteration_id = iterations[-1].id  # Get the last iteration (most recent changes)
+                iteration_id = iterations[
+                    -1
+                ].id  # Get the last iteration (most recent changes)

                 # Get changes for the iteration
                 changes = self.azure_devops_client.get_pull_request_iteration_changes(
                     repository_id=self.repo_slug,
                     pull_request_id=self.pr_num,
                     iteration_id=iteration_id,
-                    project=self.workspace_slug
+                    project=self.workspace_slug,
                 )
             diff_files = []
             diffs = []
@@ -253,7 +266,9 @@ class AzureDevopsProvider(GitProvider):
                     path = item.get('path', None)
                     if path:
                         diffs.append(path)
-                        diff_types[path] = change.additional_properties.get('changeType', 'Unknown')
+                        diff_types[path] = change.additional_properties.get(
+                            'changeType', 'Unknown'
+                        )

             # wrong implementation - gets all the files that were changed in any commit in the PR
             # commits = self.azure_devops_client.get_pull_request_commits(
@@ -284,9 +299,13 @@ class AzureDevopsProvider(GitProvider):
             diffs = filter_ignored(diffs_original, 'azure')
             if diffs_original != diffs:
                 try:
-                    get_logger().info(f"Filtered out [ignore] files for pull request:", extra=
-                                      {"files": diffs_original,  # diffs is just a list of names
-                                       "filtered_files": diffs})
+                    get_logger().info(
+                        f"Filtered out [ignore] files for pull request:",
+                        extra={
+                            "files": diffs_original,  # diffs is just a list of names
+                            "filtered_files": diffs,
+                        },
+                    )
                 except Exception:
                     pass
@@ -311,7 +330,10 @@ class AzureDevopsProvider(GitProvider):
                     new_file_content_str = new_file_content_str.content
                 except Exception as error:
-                    get_logger().error(f"Failed to retrieve new file content of {file} at version {version}", error=error)
+                    get_logger().error(
+                        f"Failed to retrieve new file content of {file} at version {version}",
+                        error=error,
+                    )
                     # get_logger().error(
                     #     "Failed to retrieve new file content of %s at version %s. Error: %s",
                     #     file,
@@ -325,7 +347,9 @@ class AzureDevopsProvider(GitProvider):
                     edit_type = EDIT_TYPE.ADDED
                 elif diff_types[file] == "delete":
                     edit_type = EDIT_TYPE.DELETED
-                elif "rename" in diff_types[file]:  # diff_type can be `rename` | `edit, rename`
+                elif (
+                    "rename" in diff_types[file]
+                ):  # diff_type can be `rename` | `edit, rename`
                     edit_type = EDIT_TYPE.RENAMED

                 version = GitVersionDescriptor(
@@ -345,17 +369,27 @@ class AzureDevopsProvider(GitProvider):
                     )
                     original_file_content_str = original_file_content_str.content
                 except Exception as error:
-                    get_logger().error(f"Failed to retrieve original file content of {file} at version {version}", error=error)
+                    get_logger().error(
+                        f"Failed to retrieve original file content of {file} at version {version}",
+                        error=error,
+                    )
                     original_file_content_str = ""

                 patch = load_large_diff(
-                    file, new_file_content_str, original_file_content_str, show_warning=False
+                    file,
+                    new_file_content_str,
+                    original_file_content_str,
+                    show_warning=False,
                 ).rstrip()

                 # count number of lines added and removed
                 patch_lines = patch.splitlines(keepends=True)
-                num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
-                num_minus_lines = len([line for line in patch_lines if line.startswith('-')])
+                num_plus_lines = len(
+                    [line for line in patch_lines if line.startswith('+')]
+                )
+                num_minus_lines = len(
+                    [line for line in patch_lines if line.startswith('-')]
+                )

                 diff_files.append(
                     FilePatchInfo(
@@ -376,26 +410,34 @@ class AzureDevopsProvider(GitProvider):
             get_logger().exception(f"Failed to get diff files, error: {e}")
             return []

-    def publish_comment(self, pr_comment: str, is_temporary: bool = False, thread_context=None):
+    def publish_comment(
+        self, pr_comment: str, is_temporary: bool = False, thread_context=None
+    ):
         if is_temporary and not get_settings().config.publish_output_progress:
-            get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
+            get_logger().debug(
+                f"Skipping publish_comment for temporary comment: {pr_comment}"
+            )
             return None
         comment = Comment(content=pr_comment)
-        thread = CommentThread(comments=[comment], thread_context=thread_context, status=5)
+        thread = CommentThread(
+            comments=[comment], thread_context=thread_context, status=5
+        )
         thread_response = self.azure_devops_client.create_thread(
             comment_thread=thread,
             project=self.workspace_slug,
             repository_id=self.repo_slug,
             pull_request_id=self.pr_num,
         )
-        response = {"thread_id": thread_response.id, "comment_id": thread_response.comments[0].id}
+        response = {
+            "thread_id": thread_response.id,
+            "comment_id": thread_response.comments[0].id,
+        }
         if is_temporary:
             self.temp_comments.append(response)
         return response

     def publish_description(self, pr_title: str, pr_body: str):
         if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:
             usage_guide_text = '<details> <summary><strong>✨ Describe tool usage guide:</strong></summary><hr>'
             ind = pr_body.find(usage_guide_text)
             if ind != -1:
@@ -409,7 +451,10 @@ class AzureDevopsProvider(GitProvider):
             if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:
                 trunction_message = " ... (description truncated due to length limit)"
-                pr_body = pr_body[:MAX_PR_DESCRIPTION_AZURE_LENGTH - len(trunction_message)] + trunction_message
+                pr_body = (
+                    pr_body[: MAX_PR_DESCRIPTION_AZURE_LENGTH - len(trunction_message)]
+                    + trunction_message
+                )
                 get_logger().warning("PR description was truncated due to length limit")
         try:
             updated_pr = GitPullRequest()
@@ -433,30 +478,58 @@ class AzureDevopsProvider(GitProvider):
         except Exception as e:
             get_logger().exception(f"Failed to remove temp comments, error: {e}")

-    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
-        self.publish_inline_comments([self.create_inline_comment(body, relevant_file, relevant_line_in_file)])
+    def publish_inline_comment(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_line_in_file: str,
+        original_suggestion=None,
+    ):
+        self.publish_inline_comments(
+            [self.create_inline_comment(body, relevant_file, relevant_line_in_file)]
+        )

-    def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
-                              absolute_position: int = None):
-        position, absolute_position = find_line_number_of_relevant_line_in_file(self.get_diff_files(),
-                                                                                relevant_file.strip('`'),
-                                                                                relevant_line_in_file,
-                                                                                absolute_position)
+    def create_inline_comment(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_line_in_file: str,
+        absolute_position: int = None,
+    ):
+        position, absolute_position = find_line_number_of_relevant_line_in_file(
+            self.get_diff_files(),
+            relevant_file.strip('`'),
+            relevant_line_in_file,
+            absolute_position,
+        )
         if position == -1:
             if get_settings().config.verbosity_level >= 2:
-                get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
+                get_logger().info(
+                    f"Could not find position for {relevant_file} {relevant_line_in_file}"
+                )
             subject_type = "FILE"
         else:
             subject_type = "LINE"
         path = relevant_file.strip()
-        return dict(body=body, path=path, position=position, absolute_position=absolute_position) if subject_type == "LINE" else {}
+        return (
+            dict(
+                body=body,
+                path=path,
+                position=position,
+                absolute_position=absolute_position,
+            )
+            if subject_type == "LINE"
+            else {}
+        )

-    def publish_inline_comments(self, comments: list[dict], disable_fallback: bool = False):
+    def publish_inline_comments(
+        self, comments: list[dict], disable_fallback: bool = False
+    ):
         overall_success = True
         for comment in comments:
             try:
-                self.publish_comment(comment["body"],
+                self.publish_comment(
+                    comment["body"],
                     thread_context={
                         "filePath": comment["path"],
                         "rightFileStart": {
@@ -467,7 +540,8 @@ class AzureDevopsProvider(GitProvider):
                             "line": comment["absolute_position"],
                             "offset": comment["position"],
                         },
-                    })
+                    },
+                )
                 if get_settings().config.verbosity_level >= 2:
                     get_logger().info(
                         f"Published code suggestion on {self.pr_num} at {comment['path']}"
@@ -521,7 +595,11 @@ class AzureDevopsProvider(GitProvider):
         return 0

     def get_issue_comments(self):
-        threads = self.azure_devops_client.get_threads(repository_id=self.repo_slug, pull_request_id=self.pr_num, project=self.workspace_slug)
+        threads = self.azure_devops_client.get_threads(
+            repository_id=self.repo_slug,
+            pull_request_id=self.pr_num,
+            project=self.workspace_slug,
+        )
         threads.reverse()
         comment_list = []
         for thread in threads:
@@ -532,7 +610,9 @@ class AzureDevopsProvider(GitProvider):
                 comment_list.append(comment)
         return comment_list

-    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
+    def add_eyes_reaction(
+        self, issue_comment_id: int, disable_eyes: bool = False
+    ) -> Optional[int]:
         return True

     def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
@@ -547,16 +627,22 @@ class AzureDevopsProvider(GitProvider):
             raise ValueError(
                 "The provided URL does not appear to be a Azure DevOps PR URL"
             )

-        if len(path_parts) == 6:  # "https://dev.azure.com/organization/project/_git/repo/pullrequest/1"
+        if (
+            len(path_parts) == 6
+        ):  # "https://dev.azure.com/organization/project/_git/repo/pullrequest/1"
             workspace_slug = path_parts[1]
             repo_slug = path_parts[3]
             pr_number = int(path_parts[5])
-        elif len(path_parts) == 5:  # 'https://organization.visualstudio.com/project/_git/repo/pullrequest/1'
+        elif (
+            len(path_parts) == 5
+        ):  # 'https://organization.visualstudio.com/project/_git/repo/pullrequest/1'
             workspace_slug = path_parts[0]
             repo_slug = path_parts[2]
             pr_number = int(path_parts[4])
         else:
-            raise ValueError("The provided URL does not appear to be a Azure DevOps PR URL")
+            raise ValueError(
+                "The provided URL does not appear to be a Azure DevOps PR URL"
+            )

         return workspace_slug, repo_slug, pr_number
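Editor's note: a standalone sketch of the URL dissection above, mirroring its segment indexing but using urlparse (the function name and return shape are mine, and the provider's surrounding validation is omitted):

```python
from urllib.parse import urlparse

def parse_azure_pr_url(url: str) -> tuple[str, str, int]:
    parts = [p for p in urlparse(url).path.split("/") if p]
    if len(parts) == 6:
        # e.g. dev.azure.com/org/project/_git/repo/pullrequest/1
        return parts[1], parts[3], int(parts[5])
    if len(parts) == 5:
        # e.g. org.visualstudio.com/project/_git/repo/pullrequest/1
        return parts[0], parts[2], int(parts[4])
    raise ValueError("The provided URL does not appear to be an Azure DevOps PR URL")
```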
@@ -575,12 +661,16 @@ class AzureDevopsProvider(GitProvider):
             # try to use azure default credentials
             # see https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python
             # for usage and env var configuration of user-assigned managed identity, local machine auth etc.
-            get_logger().info("No PAT found in settings, trying to use Azure Default Credentials.")
+            get_logger().info(
+                "No PAT found in settings, trying to use Azure Default Credentials."
+            )
             credentials = DefaultAzureCredential()
             accessToken = credentials.get_token(ADO_APP_CLIENT_DEFAULT_ID)
             auth_token = accessToken.token
         except Exception as e:
-            get_logger().error(f"No PAT found in settings, and Azure Default Authentication failed, error: {e}")
+            get_logger().error(
+                f"No PAT found in settings, and Azure Default Authentication failed, error: {e}"
+            )
             raise

         credentials = BasicAuthentication("", auth_token)

View File

@@ -52,13 +52,19 @@ class BitbucketProvider(GitProvider):
         self.git_files = None
         if pr_url:
             self.set_pr(pr_url)
-            self.bitbucket_comment_api_url = self.pr._BitbucketBase__data["links"]["comments"]["href"]
-            self.bitbucket_pull_request_api_url = self.pr._BitbucketBase__data["links"]['self']['href']
+            self.bitbucket_comment_api_url = self.pr._BitbucketBase__data["links"][
+                "comments"
+            ]["href"]
+            self.bitbucket_pull_request_api_url = self.pr._BitbucketBase__data["links"][
+                'self'
+            ]['href']

     def get_repo_settings(self):
         try:
-            url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
-                   f"{self.pr.destination_branch}/.pr_agent.toml")
+            url = (
+                f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
+                f"{self.pr.destination_branch}/.pr_agent.toml"
+            )
             response = requests.request("GET", url, headers=self.headers)
             if response.status_code == 404:  # not found
                 return ""
@@ -74,20 +80,27 @@ class BitbucketProvider(GitProvider):
         post_parameters_list = []
         for suggestion in code_suggestions:
             body = suggestion["body"]
-            original_suggestion = suggestion.get('original_suggestion', None)  # needed for diff code
+            original_suggestion = suggestion.get(
+                'original_suggestion', None
+            )  # needed for diff code
             if original_suggestion:
                 try:
                     existing_code = original_suggestion['existing_code'].rstrip() + "\n"
                     improved_code = original_suggestion['improved_code'].rstrip() + "\n"
-                    diff = difflib.unified_diff(existing_code.split('\n'),
-                                                improved_code.split('\n'), n=999)
+                    diff = difflib.unified_diff(
+                        existing_code.split('\n'), improved_code.split('\n'), n=999
+                    )
                     patch_orig = "\n".join(diff)
                     patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
                     diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
                     # replace ```suggestion ... ``` with diff_code, using regex:
-                    body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL)
+                    body = re.sub(
+                        r'```suggestion.*?```', diff_code, body, flags=re.DOTALL
+                    )
                 except Exception as e:
-                    get_logger().exception(f"Bitbucket failed to get diff code for publishing, error: {e}")
+                    get_logger().exception(
+                        f"Bitbucket failed to get diff code for publishing, error: {e}"
+                    )
                     continue

             relevant_file = suggestion["relevant_file"]
@@ -129,15 +142,22 @@ class BitbucketProvider(GitProvider):
             self.publish_inline_comments(post_parameters_list)
             return True
         except Exception as e:
-            get_logger().error(f"Bitbucket failed to publish code suggestion, error: {e}")
+            get_logger().error(
+                f"Bitbucket failed to publish code suggestion, error: {e}"
+            )
             return False

     def publish_file_comments(self, file_comments: list) -> bool:
         pass

     def is_supported(self, capability: str) -> bool:
-        if capability in ['get_issue_comments', 'publish_inline_comments', 'get_labels', 'gfm_markdown',
-                          'publish_file_comments']:
+        if capability in [
+            'get_issue_comments',
+            'publish_inline_comments',
+            'get_labels',
+            'gfm_markdown',
+            'publish_file_comments',
+        ]:
             return False
         return True
@@ -169,12 +189,14 @@ class BitbucketProvider(GitProvider):
             names_original = [d.new.path for d in diffs_original]
             names_kept = [d.new.path for d in diffs]
             names_filtered = list(set(names_original) - set(names_kept))
-            get_logger().info(f"Filtered out [ignore] files for PR", extra={
-                'original_files': names_original,
-                'names_kept': names_kept,
-                'names_filtered': names_filtered
-            })
+            get_logger().info(
+                f"Filtered out [ignore] files for PR",
+                extra={
+                    'original_files': names_original,
+                    'names_kept': names_kept,
+                    'names_filtered': names_filtered,
+                },
+            )
         except Exception as e:
             pass
@@ -189,20 +211,32 @@ class BitbucketProvider(GitProvider):
         for encoding in encodings_to_try:
             try:
                 pr_patches = self.pr.diff(encoding=encoding)
-                get_logger().info(f"Successfully decoded PR patch with encoding {encoding}")
+                get_logger().info(
+                    f"Successfully decoded PR patch with encoding {encoding}"
+                )
                 break
             except UnicodeDecodeError:
                 continue

         if pr_patches is None:
-            raise ValueError(f"Failed to decode PR patch with encodings {encodings_to_try}")
+            raise ValueError(
+                f"Failed to decode PR patch with encodings {encodings_to_try}"
+            )

-        diff_split = ["diff --git" + x for x in pr_patches.split("diff --git") if x.strip()]
+        diff_split = [
+            "diff --git" + x for x in pr_patches.split("diff --git") if x.strip()
+        ]

         # filter all elements of 'diff_split' that are of indices in 'diffs_original' that are not in 'diffs'
         if len(diff_split) > len(diffs) and len(diffs_original) == len(diff_split):
-            diff_split = [diff_split[i] for i in range(len(diff_split)) if diffs_original[i] in diffs]
+            diff_split = [
+                diff_split[i]
+                for i in range(len(diff_split))
+                if diffs_original[i] in diffs
+            ]

         if len(diff_split) != len(diffs):
-            get_logger().error(f"Error - failed to split the diff into {len(diffs)} parts")
+            get_logger().error(
+                f"Error - failed to split the diff into {len(diffs)} parts"
+            )
             return []

         # bitbucket diff has a header for each file, we need to remove it:
         # "diff --git filename
@@ -213,22 +247,34 @@ class BitbucketProvider(GitProvider):
         # @@ -... @@"
         for i, _ in enumerate(diff_split):
             diff_split_lines = diff_split[i].splitlines()
-            if (len(diff_split_lines) >= 6) and \
-                    ((diff_split_lines[2].startswith("---") and
-                      diff_split_lines[3].startswith("+++") and
-                      diff_split_lines[4].startswith("@@")) or
-                     (diff_split_lines[3].startswith("---") and  # new or deleted file
-                      diff_split_lines[4].startswith("+++") and
-                      diff_split_lines[5].startswith("@@"))):
+            if (len(diff_split_lines) >= 6) and (
+                (
+                    diff_split_lines[2].startswith("---")
+                    and diff_split_lines[3].startswith("+++")
+                    and diff_split_lines[4].startswith("@@")
+                )
+                or (
+                    diff_split_lines[3].startswith("---")
+                    and diff_split_lines[4].startswith("+++")  # new or deleted file
+                    and diff_split_lines[5].startswith("@@")
+                )
+            ):
                 diff_split[i] = "\n".join(diff_split_lines[4:])
             else:
-                if diffs[i].data.get('lines_added', 0) == 0 and diffs[i].data.get('lines_removed', 0) == 0:
+                if (
+                    diffs[i].data.get('lines_added', 0) == 0
+                    and diffs[i].data.get('lines_removed', 0) == 0
+                ):
                     diff_split[i] = ""
                 elif len(diff_split_lines) <= 3:
                     diff_split[i] = ""
-                    get_logger().info(f"Disregarding empty diff for file {_gef_filename(diffs[i])}")
+                    get_logger().info(
+                        f"Disregarding empty diff for file {_gef_filename(diffs[i])}"
+                    )
                 else:
-                    get_logger().warning(f"Bitbucket failed to get diff for file {_gef_filename(diffs[i])}")
+                    get_logger().warning(
+                        f"Bitbucket failed to get diff for file {_gef_filename(diffs[i])}"
+                    )
                     diff_split[i] = ""

         invalid_files_names = []
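Editor's note: the diff_split comprehension above relies on str.split discarding its separator, so the "diff --git" sentinel has to be glued back on. A standalone sketch of that trick (the function name and sample patch are mine):

```python
def split_patch_by_file(pr_patch: str) -> list[str]:
    # Re-attach the "diff --git" sentinel that str.split strips, so each
    # element is a self-contained per-file patch.
    return [
        "diff --git" + part
        for part in pr_patch.split("diff --git")
        if part.strip()
    ]

combined = (
    "diff --git a/x b/x\n--- a/x\n+++ b/x\n"
    "diff --git a/y b/y\n--- a/y\n+++ b/y\n"
)
assert len(split_patch_by_file(combined)) == 2
```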
@ -246,24 +292,32 @@ class BitbucketProvider(GitProvider):
if get_settings().get("bitbucket_app.avoid_full_files", False): if get_settings().get("bitbucket_app.avoid_full_files", False):
original_file_content_str = "" original_file_content_str = ""
new_file_content_str = "" new_file_content_str = ""
elif counter_valid < MAX_FILES_ALLOWED_FULL // 2: # factor 2 because bitbucket has limited API calls elif (
counter_valid < MAX_FILES_ALLOWED_FULL // 2
): # factor 2 because bitbucket has limited API calls
if diff.old.get_data("links"): if diff.old.get_data("links"):
original_file_content_str = self._get_pr_file_content( original_file_content_str = self._get_pr_file_content(
diff.old.get_data("links")['self']['href']) diff.old.get_data("links")['self']['href']
)
else: else:
original_file_content_str = "" original_file_content_str = ""
if diff.new.get_data("links"): if diff.new.get_data("links"):
new_file_content_str = self._get_pr_file_content(diff.new.get_data("links")['self']['href']) new_file_content_str = self._get_pr_file_content(
diff.new.get_data("links")['self']['href']
)
else: else:
new_file_content_str = "" new_file_content_str = ""
else: else:
if counter_valid == MAX_FILES_ALLOWED_FULL // 2: if counter_valid == MAX_FILES_ALLOWED_FULL // 2:
get_logger().info( get_logger().info(
f"Bitbucket too many files in PR, will avoid loading full content for rest of files") f"Bitbucket too many files in PR, will avoid loading full content for rest of files"
)
original_file_content_str = "" original_file_content_str = ""
new_file_content_str = "" new_file_content_str = ""
except Exception as e: except Exception as e:
get_logger().exception(f"Error - bitbucket failed to get file content, error: {e}") get_logger().exception(
f"Error - bitbucket failed to get file content, error: {e}"
)
original_file_content_str = "" original_file_content_str = ""
new_file_content_str = "" new_file_content_str = ""
@ -285,7 +339,9 @@ class BitbucketProvider(GitProvider):
diff_files.append(file_patch_canonic_structure) diff_files.append(file_patch_canonic_structure)
if invalid_files_names: if invalid_files_names:
get_logger().info(f"Disregarding files with invalid extensions:\n{invalid_files_names}") get_logger().info(
f"Disregarding files with invalid extensions:\n{invalid_files_names}"
)
self.diff_files = diff_files self.diff_files = diff_files
return diff_files return diff_files
@@ -296,11 +352,14 @@ class BitbucketProvider(GitProvider):
    def get_comment_url(self, comment):
        return comment.data['links']['html']['href']

    def publish_persistent_comment(
        self,
        pr_comment: str,
        initial_header: str,
        update_header: bool = True,
        name='review',
        final_update_message=True,
    ):
        try:
            for comment in self.pr.comments():
                body = comment.raw
@@ -309,15 +368,20 @@ class BitbucketProvider(GitProvider):
                    comment_url = self.get_comment_url(comment)
                    if update_header:
                        updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n"
                        pr_comment_updated = pr_comment.replace(
                            initial_header, updated_header
                        )
                    else:
                        pr_comment_updated = pr_comment
                    get_logger().info(
                        f"Persistent mode - updating comment {comment_url} to latest {name} message"
                    )
                    d = {"content": {"raw": pr_comment_updated}}
                    response = comment._update_data(comment.put(None, data=d))
                    if final_update_message:
                        self.publish_comment(
                            f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}"
                        )
                    return
        except Exception as e:
            get_logger().exception(f"Failed to update persistent review, error: {e}")
@@ -326,7 +390,9 @@ class BitbucketProvider(GitProvider):
    def publish_comment(self, pr_comment: str, is_temporary: bool = False):
        if is_temporary and not get_settings().config.publish_output_progress:
            get_logger().debug(
                f"Skipping publish_comment for temporary comment: {pr_comment}"
            )
            return None
        pr_comment = self.limit_output_characters(pr_comment, self.max_comment_length)
        comment = self.pr.comment(pr_comment)
@@ -355,39 +421,58 @@ class BitbucketProvider(GitProvider):
            get_logger().exception(f"Failed to remove comment, error: {e}")

    # function to create_inline_comment
    def create_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        absolute_position: int = None,
    ):
        body = self.limit_output_characters(body, self.max_comment_length)
        position, absolute_position = find_line_number_of_relevant_line_in_file(
            self.get_diff_files(),
            relevant_file.strip('`'),
            relevant_line_in_file,
            absolute_position,
        )
        if position == -1:
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(
                    f"Could not find position for {relevant_file} {relevant_line_in_file}"
                )
            subject_type = "FILE"
        else:
            subject_type = "LINE"
        path = relevant_file.strip()
        return (
            dict(body=body, path=path, position=absolute_position)
            if subject_type == "LINE"
            else {}
        )

    def publish_inline_comment(
        self, comment: str, from_line: int, file: str, original_suggestion=None
    ):
        comment = self.limit_output_characters(comment, self.max_comment_length)
        payload = json.dumps(
            {
                "content": {
                    "raw": comment,
                },
                "inline": {"to": from_line, "path": file},
            }
        )
        response = requests.request(
            "POST", self.bitbucket_comment_api_url, data=payload, headers=self.headers
        )
        return response

    def get_line_link(
        self,
        relevant_file: str,
        relevant_line_start: int,
        relevant_line_end: int = None,
    ) -> str:
        if relevant_line_start == -1:
            link = f"{self.pr_url}/#L{relevant_file}"
        else:
@@ -402,8 +487,9 @@ class BitbucketProvider(GitProvider):
                return ""
            diff_files = self.get_diff_files()
            position, absolute_position = find_line_number_of_relevant_line_in_file(
                diff_files, relevant_file, relevant_line_str
            )

            if absolute_position != -1 and self.pr_url:
                link = f"{self.pr_url}/#L{relevant_file}T{absolute_position}"
@@ -417,12 +503,18 @@ class BitbucketProvider(GitProvider):
    def publish_inline_comments(self, comments: list[dict]):
        for comment in comments:
            if 'position' in comment:
                self.publish_inline_comment(
                    comment['body'], comment['position'], comment['path']
                )
            elif 'start_line' in comment:  # multi-line comment
                # note that bitbucket does not seem to support range - only a comment on a single line - https://community.developer.atlassian.com/t/api-post-endpoint-for-inline-pull-request-comments/60452
                self.publish_inline_comment(
                    comment['body'], comment['start_line'], comment['path']
                )
            elif 'line' in comment:  # single-line comment
                self.publish_inline_comment(
                    comment['body'], comment['line'], comment['path']
                )
            else:
                get_logger().error(f"Could not publish inline comment {comment}")
@@ -450,7 +542,9 @@ class BitbucketProvider(GitProvider):
            "Bitbucket provider does not support issue comments yet"
        )

    def add_eyes_reaction(
        self, issue_comment_id: int, disable_eyes: bool = False
    ) -> Optional[int]:
        return True

    def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
@@ -495,8 +589,10 @@ class BitbucketProvider(GitProvider):
                branch = self.pr.data["source"]["commit"]["hash"]
            elif branch == self.pr.destination_branch:
                branch = self.pr.data["destination"]["commit"]["hash"]
            url = (
                f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
                f"{branch}/{file_path}"
            )
            response = requests.request("GET", url, headers=self.headers)
            if response.status_code == 404:  # not found
                return ""
@@ -505,23 +601,28 @@ class BitbucketProvider(GitProvider):
        except Exception:
            return ""
    def create_or_update_pr_file(
        self, file_path: str, branch: str, contents="", message=""
    ) -> None:
        url = f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
        if not message:
            if contents:
                message = f"Update {file_path}"
            else:
                message = f"Create {file_path}"
        files = {file_path: contents}
        data = {"message": message, "branch": branch}
        headers = (
            {'Authorization': self.headers['Authorization']}
            if 'Authorization' in self.headers
            else {}
        )
        try:
            requests.request("POST", url, headers=headers, data=data, files=files)
        except Exception:
            get_logger().exception(
                f"Failed to create empty file {file_path} in branch {branch}"
            )

    def _get_pr_file_content(self, remote_link: str):
        try:
@@ -538,16 +639,19 @@ class BitbucketProvider(GitProvider):
    # bitbucket does not support labels
    def publish_description(self, pr_title: str, description: str):
        payload = json.dumps({"description": description, "title": pr_title})

        response = requests.request(
            "PUT",
            self.bitbucket_pull_request_api_url,
            headers=self.headers,
            data=payload,
        )
        try:
            if response.status_code != 200:
                get_logger().info(
                    f"Failed to update description, error code: {response.status_code}"
                )
        except:
            pass
        return response
@@ -11,8 +11,7 @@ from requests.exceptions import HTTPError
from ..algo.git_patch_processing import decode_if_bytes
from ..algo.language_handler import is_valid_file
from ..algo.types import EDIT_TYPE, FilePatchInfo
from ..algo.utils import find_line_number_of_relevant_line_in_file, load_large_diff
from ..config_loader import get_settings
from ..log import get_logger
from .git_provider import GitProvider
@@ -20,7 +19,9 @@ from .git_provider import GitProvider
class BitbucketServerProvider(GitProvider):
    def __init__(
        self,
        pr_url: Optional[str] = None,
        incremental: Optional[bool] = False,
        bitbucket_client: Optional[Bitbucket] = None,
    ):
        self.bitbucket_server_url = None
@@ -36,11 +37,16 @@ class BitbucketServerProvider(GitProvider):
            self.bitbucket_pull_request_api_url = pr_url

            self.bitbucket_server_url = self._parse_bitbucket_server(url=pr_url)
            self.bitbucket_client = bitbucket_client or Bitbucket(
                url=self.bitbucket_server_url,
                token=get_settings().get("BITBUCKET_SERVER.BEARER_TOKEN", None),
            )
            try:
                self.bitbucket_api_version = parse_version(
                    self.bitbucket_client.get("rest/api/1.0/application-properties").get(
                        'version'
                    )
                )
            except Exception:
                self.bitbucket_api_version = None
@@ -49,7 +55,12 @@ class BitbucketServerProvider(GitProvider):
    def get_repo_settings(self):
        try:
            content = self.bitbucket_client.get_content_of_file(
                self.workspace_slug,
                self.repo_slug,
                ".pr_agent.toml",
                self.get_pr_branch(),
            )

            return content
        except Exception as e:
@@ -70,20 +81,27 @@ class BitbucketServerProvider(GitProvider):
        post_parameters_list = []
        for suggestion in code_suggestions:
            body = suggestion["body"]
            original_suggestion = suggestion.get(
                'original_suggestion', None
            )  # needed for diff code
            if original_suggestion:
                try:
                    existing_code = original_suggestion['existing_code'].rstrip() + "\n"
                    improved_code = original_suggestion['improved_code'].rstrip() + "\n"
                    diff = difflib.unified_diff(
                        existing_code.split('\n'), improved_code.split('\n'), n=999
                    )
                    patch_orig = "\n".join(diff)
                    patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
                    diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
                    # replace ```suggestion ... ``` with diff_code, using regex:
                    body = re.sub(
                        r'```suggestion.*?```', diff_code, body, flags=re.DOTALL
                    )
                except Exception as e:
                    get_logger().exception(
                        f"Bitbucket failed to get diff code for publishing, error: {e}"
                    )
                    continue

            relevant_file = suggestion["relevant_file"]
            relevant_lines_start = suggestion["relevant_lines_start"]
@@ -134,7 +152,12 @@ class BitbucketServerProvider(GitProvider):
        pass

    def is_supported(self, capability: str) -> bool:
        if capability in [
            'get_issue_comments',
            'get_labels',
            'gfm_markdown',
            'publish_file_comments',
        ]:
            return False
        return True
@@ -145,23 +168,28 @@ class BitbucketServerProvider(GitProvider):
    def get_file(self, path: str, commit_id: str):
        file_content = ""
        try:
            file_content = self.bitbucket_client.get_content_of_file(
                self.workspace_slug, self.repo_slug, path, commit_id
            )
        except HTTPError as e:
            get_logger().debug(f"File {path} not found at commit id: {commit_id}")
        return file_content

    def get_files(self):
        changes = self.bitbucket_client.get_pull_requests_changes(
            self.workspace_slug, self.repo_slug, self.pr_num
        )
        diffstat = [change["path"]['toString'] for change in changes]
        return diffstat

    # gets the best common ancestor: https://git-scm.com/docs/git-merge-base
    @staticmethod
    def get_best_common_ancestor(
        source_commits_list, destination_commits_list, guaranteed_common_ancestor
    ) -> str:
        destination_commit_hashes = {
            commit['id'] for commit in destination_commits_list
        } | {guaranteed_common_ancestor}

        for commit in source_commits_list:
            for parent_commit in commit['parents']:
@@ -177,37 +205,55 @@ class BitbucketServerProvider(GitProvider):
        head_sha = self.pr.fromRef['latestCommit']

        # if Bitbucket api version is >= 8.16 then use the merge-base api for 2-way diff calculation
        if (
            self.bitbucket_api_version is not None
            and self.bitbucket_api_version >= parse_version("8.16")
        ):
            try:
                base_sha = self.bitbucket_client.get(self._get_merge_base())['id']
            except Exception as e:
                get_logger().error(
                    f"Failed to get the best common ancestor for PR: {self.pr_url}, \nerror: {e}"
                )
                raise e
        else:
            source_commits_list = list(
                self.bitbucket_client.get_pull_requests_commits(
                    self.workspace_slug, self.repo_slug, self.pr_num
                )
            )
            # if Bitbucket api version is None or < 7.0 then do a simple diff with a guaranteed common ancestor
            base_sha = source_commits_list[-1]['parents'][0]['id']
            # if Bitbucket api version is 7.0-8.15 then use 2-way diff functionality for the base_sha
            if (
                self.bitbucket_api_version is not None
                and self.bitbucket_api_version >= parse_version("7.0")
            ):
                try:
                    destination_commits = list(
                        self.bitbucket_client.get_commits(
                            self.workspace_slug,
                            self.repo_slug,
                            base_sha,
                            self.pr.toRef['latestCommit'],
                        )
                    )
                    base_sha = self.get_best_common_ancestor(
                        source_commits_list, destination_commits, base_sha
                    )
                except Exception as e:
                    get_logger().error(
                        f"Failed to get the commit list for calculating best common ancestor for PR: {self.pr_url}, \nerror: {e}"
                    )
                    raise e

        diff_files = []
        original_file_content_str = ""
        new_file_content_str = ""

        changes = self.bitbucket_client.get_pull_requests_changes(
            self.workspace_slug, self.repo_slug, self.pr_num
        )
        for change in changes:
            file_path = change['path']['toString']
            if not is_valid_file(file_path.split("/")[-1]):
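A self-contained sketch of the ancestor search used above, assuming each commit is a dict with 'id' and 'parents' (each parent carrying an 'id') and that the source list is ordered newest to oldest, as the hunks suggest:

def best_common_ancestor(source_commits, destination_commits, guaranteed):
    # Hashes reachable on the destination side, plus the guaranteed fallback.
    destination_hashes = {c["id"] for c in destination_commits} | {guaranteed}
    for commit in source_commits:
        for parent in commit["parents"]:
            if parent["id"] in destination_hashes:
                return parent["id"]  # first parent already on the destination side
    return guaranteed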
@@ -224,17 +270,26 @@ class BitbucketServerProvider(GitProvider):
                    edit_type = EDIT_TYPE.DELETED
                    new_file_content_str = ""
                    original_file_content_str = self.get_file(file_path, base_sha)
                    original_file_content_str = decode_if_bytes(
                        original_file_content_str
                    )
                case 'RENAME':
                    edit_type = EDIT_TYPE.RENAMED
                case _:
                    edit_type = EDIT_TYPE.MODIFIED
                    original_file_content_str = self.get_file(file_path, base_sha)
                    original_file_content_str = decode_if_bytes(
                        original_file_content_str
                    )
                    new_file_content_str = self.get_file(file_path, head_sha)
                    new_file_content_str = decode_if_bytes(new_file_content_str)

            patch = load_large_diff(
                file_path,
                new_file_content_str,
                original_file_content_str,
                show_warning=False,
            )

            diff_files.append(
                FilePatchInfo(
@@ -251,7 +306,9 @@ class BitbucketServerProvider(GitProvider):
    def publish_comment(self, pr_comment: str, is_temporary: bool = False):
        if not is_temporary:
            self.bitbucket_client.add_pull_request_comment(
                self.workspace_slug, self.repo_slug, self.pr_num, pr_comment
            )
    def remove_initial_comment(self):
        try:
@@ -264,25 +321,37 @@ class BitbucketServerProvider(GitProvider):
            pass

    # function to create_inline_comment
    def create_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        absolute_position: int = None,
    ):
        position, absolute_position = find_line_number_of_relevant_line_in_file(
            self.get_diff_files(),
            relevant_file.strip('`'),
            relevant_line_in_file,
            absolute_position,
        )
        if position == -1:
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(
                    f"Could not find position for {relevant_file} {relevant_line_in_file}"
                )
            subject_type = "FILE"
        else:
            subject_type = "LINE"
        path = relevant_file.strip()
        return (
            dict(body=body, path=path, position=absolute_position)
            if subject_type == "LINE"
            else {}
        )

    def publish_inline_comment(
        self, comment: str, from_line: int, file: str, original_suggestion=None
    ):
        payload = {
            "text": comment,
            "severity": "NORMAL",
@@ -291,17 +360,24 @@ class BitbucketServerProvider(GitProvider):
                "path": file,
                "lineType": "ADDED",
                "line": from_line,
                "fileType": "TO",
            },
        }
        try:
            self.bitbucket_client.post(self._get_pr_comments_path(), data=payload)
        except Exception as e:
            get_logger().error(
                f"Failed to publish inline comment to '{file}' at line {from_line}, error: {e}"
            )
            raise e

    def get_line_link(
        self,
        relevant_file: str,
        relevant_line_start: int,
        relevant_line_end: int = None,
    ) -> str:
        if relevant_line_start == -1:
            link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}"
        else:
@@ -316,8 +392,9 @@ class BitbucketServerProvider(GitProvider):
                return ""
            diff_files = self.get_diff_files()
            position, absolute_position = find_line_number_of_relevant_line_in_file(
                diff_files, relevant_file, relevant_line_str
            )

            if absolute_position != -1:
                if self.pr:
@@ -325,29 +402,41 @@ class BitbucketServerProvider(GitProvider):
                    return link
                else:
                    if get_settings().config.verbosity_level >= 2:
                        get_logger().info(
                            f"Failed adding line link to '{relevant_file}' since PR not set"
                        )
            else:
                if get_settings().config.verbosity_level >= 2:
                    get_logger().info(
                        f"Failed adding line link to '{relevant_file}' since position not found"
                    )

            if absolute_position != -1 and self.pr_url:
                link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}?t={absolute_position}"
                return link
        except Exception as e:
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(
                    f"Failed adding line link to '{relevant_file}', error: {e}"
                )

        return ""

    def publish_inline_comments(self, comments: list[dict]):
        for comment in comments:
            if 'position' in comment:
                self.publish_inline_comment(
                    comment['body'], comment['position'], comment['path']
                )
            elif 'start_line' in comment:  # multi-line comment
                # note that bitbucket does not seem to support range - only a comment on a single line - https://community.developer.atlassian.com/t/api-post-endpoint-for-inline-pull-request-comments/60452
                self.publish_inline_comment(
                    comment['body'], comment['start_line'], comment['path']
                )
            elif 'line' in comment:  # single-line comment
                self.publish_inline_comment(
                    comment['body'], comment['line'], comment['path']
                )
            else:
                get_logger().error(f"Could not publish inline comment: {comment}")
@@ -377,7 +466,9 @@ class BitbucketServerProvider(GitProvider):
            "Bitbucket provider does not support issue comments yet"
        )

    def add_eyes_reaction(
        self, issue_comment_id: int, disable_eyes: bool = False
    ) -> Optional[int]:
        return True

    def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
@@ -411,14 +502,20 @@ class BitbucketServerProvider(GitProvider):
            users_index = -1

        if projects_index == -1 and users_index == -1:
            raise ValueError(
                f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL"
            )

        if projects_index != -1:
            path_parts = path_parts[projects_index:]
        else:
            path_parts = path_parts[users_index:]

        if (
            len(path_parts) < 6
            or path_parts[2] != "repos"
            or path_parts[4] != "pull-requests"
        ):
            raise ValueError(
                f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL"
            )
@@ -430,19 +527,24 @@ class BitbucketServerProvider(GitProvider):
        try:
            pr_number = int(path_parts[5])
        except ValueError as e:
            raise ValueError(
                f"Unable to convert PR number '{path_parts[5]}' to integer"
            ) from e

        return workspace_slug, repo_slug, pr_number
    def _get_repo(self):
        if self.repo is None:
            self.repo = self.bitbucket_client.get_repo(
                self.workspace_slug, self.repo_slug
            )
        return self.repo

    def _get_pr(self):
        try:
            pr = self.bitbucket_client.get_pull_request(
                self.workspace_slug, self.repo_slug, pull_request_id=self.pr_num
            )
            return type('new_dict', (object,), pr)
        except Exception as e:
            get_logger().error(f"Failed to get pull request, error: {e}")
@@ -460,10 +562,12 @@ class BitbucketServerProvider(GitProvider):
            "version": self.pr.version,
            "description": description,
            "title": pr_title,
            "reviewers": self.pr.reviewers,  # needs to be sent otherwise gets wiped
        }
        try:
            self.bitbucket_client.update_pull_request(
                self.workspace_slug, self.repo_slug, str(self.pr_num), payload
            )
        except Exception as e:
            get_logger().error(f"Failed to update pull request, error: {e}")
            raise e
@@ -31,7 +31,9 @@ class CodeCommitPullRequestResponse:
        self.targets = []
        for target in json.get("pullRequestTargets", []):
            self.targets.append(
                CodeCommitPullRequestResponse.CodeCommitPullRequestTarget(target)
            )

    class CodeCommitPullRequestTarget:
        """
@@ -65,7 +67,9 @@ class CodeCommitClient:
        except Exception as e:
            raise ValueError(f"Failed to connect to AWS CodeCommit: {e}") from e

    def get_differences(
        self, repo_name: int, destination_commit: str, source_commit: str
    ):
        """
        Get the differences between two commits in CodeCommit.
@@ -96,17 +100,25 @@ class CodeCommitClient:
                differences.extend(page.get("differences", []))
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
                raise ValueError(
                    f"CodeCommit cannot retrieve differences: Repository does not exist: {repo_name}"
                ) from e
            raise ValueError(
                f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}"
            ) from e
        except Exception as e:
            raise ValueError(
                f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}"
            ) from e

        output = []
        for json in differences:
            output.append(CodeCommitDifferencesResponse(json))
        return output

    def get_file(
        self, repo_name: str, file_path: str, sha_hash: str, optional: bool = False
    ):
        """
        Retrieve a file from CodeCommit.
@@ -129,16 +141,24 @@ class CodeCommitClient:
        self._connect_boto_client()

        try:
            response = self.boto_client.get_file(
                repositoryName=repo_name, commitSpecifier=sha_hash, filePath=file_path
            )
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
                raise ValueError(
                    f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}"
                ) from e
            # if the file does not exist, but is flagged as optional, then return an empty string
            if optional and e.response["Error"]["Code"] == 'FileDoesNotExistException':
                return ""
            raise ValueError(
                f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'"
            ) from e
        except Exception as e:
            raise ValueError(
                f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'"
            ) from e
        if "fileContent" not in response:
            raise ValueError(f"File content is empty for file: {file_path}")
@@ -166,10 +186,16 @@ class CodeCommitClient:
            response = self.boto_client.get_pull_request(pullRequestId=str(pr_number))
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
                raise ValueError(
                    f"CodeCommit cannot retrieve PR: PR number does not exist: {pr_number}"
                ) from e
            if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
                raise ValueError(
                    f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}"
                ) from e
            raise ValueError(
                f"CodeCommit cannot retrieve PR: {pr_number}: boto client error"
            ) from e
        except Exception as e:
            raise ValueError(f"CodeCommit cannot retrieve PR: {pr_number}") from e
@@ -200,22 +226,37 @@ class CodeCommitClient:
        self._connect_boto_client()

        try:
            self.boto_client.update_pull_request_title(
                pullRequestId=str(pr_number), title=pr_title
            )
            self.boto_client.update_pull_request_description(
                pullRequestId=str(pr_number), description=pr_body
            )
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
                raise ValueError(f"PR number does not exist: {pr_number}") from e
            if e.response["Error"]["Code"] == 'InvalidTitleException':
                raise ValueError(f"Invalid title for PR number: {pr_number}") from e
            if e.response["Error"]["Code"] == 'InvalidDescriptionException':
                raise ValueError(
                    f"Invalid description for PR number: {pr_number}"
                ) from e
            if e.response["Error"]["Code"] == 'PullRequestAlreadyClosedException':
                raise ValueError(f"PR is already closed: PR number: {pr_number}") from e
            raise ValueError(f"Boto3 client error calling publish_description") from e
        except Exception as e:
            raise ValueError(f"Error calling publish_description") from e

    def publish_comment(
        self,
        repo_name: str,
        pr_number: int,
        destination_commit: str,
        source_commit: str,
        comment: str,
        annotation_file: str = None,
        annotation_line: int = None,
    ):
        """
        Publish a comment to a pull request
@@ -272,6 +313,8 @@ class CodeCommitClient:
                raise ValueError(f"Repository does not exist: {repo_name}") from e
            if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
                raise ValueError(f"PR number does not exist: {pr_number}") from e
            raise ValueError(
                f"Boto3 client error calling post_comment_for_pull_request"
            ) from e
        except Exception as e:
            raise ValueError(f"Error calling post_comment_for_pull_request") from e
@@ -55,7 +55,9 @@ class CodeCommitProvider(GitProvider):
    This class implements the GitProvider interface for AWS CodeCommit repositories.
    """

    def __init__(
        self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
    ):
        self.codecommit_client = CodeCommitClient()
        self.aws_client = None
        self.repo_name = None
@@ -76,7 +78,7 @@ class CodeCommitProvider(GitProvider):
            "create_inline_comment",
            "publish_inline_comments",
            "get_labels",
            "gfm_markdown",
        ]:
            return False
        return True
@@ -91,13 +93,19 @@ class CodeCommitProvider(GitProvider):
            return self.git_files

        self.git_files = []
        differences = self.codecommit_client.get_differences(
            self.repo_name, self.pr.destination_commit, self.pr.source_commit
        )
        for item in differences:
            self.git_files.append(
                CodeCommitFile(
                    item.before_blob_path,
                    item.before_blob_id,
                    item.after_blob_path,
                    item.after_blob_id,
                    CodeCommitProvider._get_edit_type(item.change_type),
                )
            )
        return self.git_files

    def get_diff_files(self) -> list[FilePatchInfo]:
@@ -121,21 +129,28 @@ class CodeCommitProvider(GitProvider):
            if diff_item.a_blob_id is not None:
                patch_filename = diff_item.a_path
                original_file_content_str = self.codecommit_client.get_file(
                    self.repo_name, diff_item.a_path, self.pr.destination_commit
                )
                if isinstance(original_file_content_str, (bytes, bytearray)):
                    original_file_content_str = original_file_content_str.decode(
                        "utf-8"
                    )
            else:
                original_file_content_str = ""

            if diff_item.b_blob_id is not None:
                patch_filename = diff_item.b_path
                new_file_content_str = self.codecommit_client.get_file(
                    self.repo_name, diff_item.b_path, self.pr.source_commit
                )
                if isinstance(new_file_content_str, (bytes, bytearray)):
                    new_file_content_str = new_file_content_str.decode("utf-8")
            else:
                new_file_content_str = ""

            patch = load_large_diff(
                patch_filename, new_file_content_str, original_file_content_str
            )

            # Store the diffs as a list of FilePatchInfo objects
            info = FilePatchInfo(
@@ -164,7 +179,9 @@ class CodeCommitProvider(GitProvider):
                pr_body=CodeCommitProvider._add_additional_newlines(pr_body),
            )
        except Exception as e:
            raise ValueError(
                f"CodeCommit Cannot publish description for PR: {self.pr_num}"
            ) from e
    def publish_comment(self, pr_comment: str, is_temporary: bool = False):
        if is_temporary:
@@ -183,19 +200,28 @@ class CodeCommitProvider(GitProvider):
                comment=pr_comment,
            )
        except Exception as e:
            raise ValueError(
                f"CodeCommit Cannot publish comment for PR: {self.pr_num}"
            ) from e

    def publish_code_suggestions(self, code_suggestions: list) -> bool:
        counter = 1
        for suggestion in code_suggestions:
            # Verify that each suggestion has the required keys
            if not all(
                key in suggestion
                for key in ["body", "relevant_file", "relevant_lines_start"]
            ):
                get_logger().warning(
                    f"Skipping code suggestion #{counter}: Each suggestion must have 'body', 'relevant_file', 'relevant_lines_start' keys"
                )
                continue

            # Publish the code suggestion to CodeCommit
            try:
                get_logger().debug(
                    f"Code Suggestion #{counter} in file: {suggestion['relevant_file']}: {suggestion['relevant_lines_start']}"
                )
                self.codecommit_client.publish_comment(
                    repo_name=self.repo_name,
                    pr_number=self.pr_num,
@@ -206,7 +232,9 @@ class CodeCommitProvider(GitProvider):
                    annotation_line=suggestion["relevant_lines_start"],
                )
            except Exception as e:
                raise ValueError(
                    f"CodeCommit Cannot publish code suggestions for PR: {self.pr_num}"
                ) from e

            counter += 1
@@ -227,12 +255,22 @@ class CodeCommitProvider(GitProvider):
    def remove_comment(self, comment):
        return ""  # not implemented yet

    def publish_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        original_suggestion=None,
    ):
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/post_comment_for_compared_commit.html
        raise NotImplementedError(
            "CodeCommit provider does not support publishing inline comments yet"
        )

    def publish_inline_comments(self, comments: list[dict]):
        raise NotImplementedError(
            "CodeCommit provider does not support publishing inline comments yet"
        )

    def get_title(self):
        return self.pr.title
@@ -270,7 +308,9 @@ class CodeCommitProvider(GitProvider):
        # We build that language->extension dictionary here in main_extensions_flat.
        main_extensions_flat = {}
        language_extension_map_org = get_settings().language_extension_map_org
        language_extension_map = {
            k.lower(): v for k, v in language_extension_map_org.items()
        }
        for language, extensions in language_extension_map.items():
            for ext in extensions:
                main_extensions_flat[ext] = language
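The loop above flattens a language -> extensions map into an extension -> language lookup; the same inversion in isolation, with made-up sample data:

language_extension_map = {"python": [".py"], "go": [".go", ".mod"]}  # sample data
main_extensions_flat = {
    ext: language
    for language, extensions in language_extension_map.items()
    for ext in extensions
}
assert main_extensions_flat[".mod"] == "go"  # O(1) language lookup per extension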
@@ -292,14 +332,20 @@ class CodeCommitProvider(GitProvider):
        return -1  # not implemented yet

    def get_issue_comments(self):
        raise NotImplementedError(
            "CodeCommit provider does not support issue comments yet"
        )

    def get_repo_settings(self):
        # a local ".pr_agent.toml" settings file is optional
        settings_filename = ".pr_agent.toml"
        return self.codecommit_client.get_file(
            self.repo_name, settings_filename, self.pr.source_commit, optional=True
        )

    def add_eyes_reaction(
        self, issue_comment_id: int, disable_eyes: bool = False
    ) -> Optional[int]:
        get_logger().info("CodeCommit provider does not support eyes reaction yet")
        return True
@@ -323,7 +369,9 @@ class CodeCommitProvider(GitProvider):
        parsed_url = urlparse(pr_url)

        if not CodeCommitProvider._is_valid_codecommit_hostname(parsed_url.netloc):
            raise ValueError(
                f"The provided URL is not a valid CodeCommit URL: {pr_url}"
            )

        path_parts = parsed_url.path.strip("/").split("/")
@@ -334,14 +382,18 @@ class CodeCommitProvider(GitProvider):
            or path_parts[2] != "repositories"
            or path_parts[4] != "pull-requests"
        ):
            raise ValueError(
                f"The provided URL does not appear to be a CodeCommit PR URL: {pr_url}"
            )

        repo_name = path_parts[3]

        try:
            pr_number = int(path_parts[5])
        except ValueError as e:
            raise ValueError(
                f"Unable to convert PR number to integer: '{path_parts[5]}'"
            ) from e

        return repo_name, pr_number
@@ -359,7 +411,12 @@ class CodeCommitProvider(GitProvider):
        Returns:
        - bool: True if the hostname is valid, False otherwise.
        """
        return (
            re.match(
                r"^[a-z]{2}-(gov-)?[a-z]+-\d\.console\.aws\.amazon\.com$", hostname
            )
            is not None
        )

    def _get_pr(self):
        response = self.codecommit_client.get_pr(self.repo_name, self.pr_num)
@@ -38,10 +38,7 @@ def clone(url, directory):
def fetch(url, refspec, cwd):
    get_logger().info("Fetching %s %s", url, refspec)
    stdout = _call('git', 'fetch', '--depth', '2', url, refspec, cwd=cwd)
    get_logger().info(stdout)
@@ -75,10 +72,13 @@ def add_comment(url: urllib3.util.Url, refspec, message):
    message = "'" + message.replace("'", "'\"'\"'") + "'"
    return _call(
        "ssh",
        "-p",
        str(url.port),
        f"{url.auth}@{url.host}",
        "gerrit",
        "review",
        "--message",
        message,
        # "--code-review", score,
        f"{patchset},{changenum}",
    )
@@ -88,19 +88,23 @@ def list_comments(url: urllib3.util.Url, refspec):
    *_, patchset, _ = refspec.rsplit("/")
    stdout = _call(
        "ssh",
        "-p",
        str(url.port),
        f"{url.auth}@{url.host}",
        "gerrit",
        "query",
        "--comments",
        "--current-patch-set",
        patchset,
        "--format",
        "JSON",
    )
    change_set, *_ = stdout.splitlines()
    return json.loads(change_set)["currentPatchSet"]["comments"]


def prepare_repo(url: urllib3.util.Url, project, refspec):
    repo_url = f"{url.scheme}://{url.auth}@{url.host}:{url.port}/{project}"

    directory = pathlib.Path(mkdtemp())
    clone(repo_url, directory),
@@ -114,18 +118,18 @@ def adopt_to_gerrit_message(message):
    buf = []
    for line in lines:
        # remove markdown formatting
        line = (
            line.replace("*", "")
            .replace("``", "`")
            .replace("<details>", "")
            .replace("</details>", "")
            .replace("<summary>", "")
            .replace("</summary>", "")
        )

        line = line.strip()
        if line.startswith('#'):
            buf.append("\n" + line.replace('#', '').removesuffix(":").strip() + ":")
            continue
        elif line.startswith('-'):
            buf.append(line.removeprefix('-').strip())
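A small worked example of the transformation above, with behaviour inferred from the visible hunk: emphasis and <details>/<summary> tags are stripped, and a '#' heading becomes a "Heading:" line:

line = "## PR **Review**:"
line = line.replace("*", "").replace("<details>", "").replace("</details>", "")
line = line.strip()
if line.startswith('#'):
    line = "\n" + line.replace('#', '').removesuffix(":").strip() + ":"
assert line == "\nPR Review:"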
@@ -136,10 +140,7 @@ def adopt_to_gerrit_message(message):
def add_suggestion(src_filename, context: str, start, end: int):
    with NamedTemporaryFile("w", delete=False) as tmp, open(src_filename, "r") as src:
        lines = src.readlines()
        tmp.writelines(lines[: start - 1])
        if context:
@@ -151,10 +152,8 @@ def add_suggestion(src_filename, context: str, start, end: int):
def upload_patch(patch, path):
    patch_server_endpoint = get_settings().get('gerrit.patch_server_endpoint')
    patch_server_token = get_settings().get('gerrit.patch_server_token')

    response = requests.post(
        patch_server_endpoint,
@@ -165,7 +164,7 @@ def upload_patch(patch, path):
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {patch_server_token}",
        },
    )
    response.raise_for_status()
    patch_server_endpoint = patch_server_endpoint.rstrip("/")
@@ -173,7 +172,6 @@ def upload_patch(patch, path):
class GerritProvider(GitProvider):
    def __init__(self, key: str, incremental=False):
        self.project, self.refspec = key.split(':')
        assert self.project, "Project name is required"
@@ -188,9 +186,7 @@ class GerritProvider(GitProvider):
            f"{parsed.scheme}://{user}@{parsed.host}:{parsed.port}"
        )
        self.repo_path = prepare_repo(self.parsed_url, self.project, self.refspec)
        self.repo = Repo(self.repo_path)
        assert self.repo
        self.pr_url = base_url
@ -210,15 +206,18 @@ class GerritProvider(GitProvider):
def get_pr_labels(self, update=False): def get_pr_labels(self, update=False):
raise NotImplementedError( raise NotImplementedError(
'Getting labels is not implemented for the gerrit provider') 'Getting labels is not implemented for the gerrit provider'
)
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False): def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False):
raise NotImplementedError( raise NotImplementedError(
'Adding reactions is not implemented for the gerrit provider') 'Adding reactions is not implemented for the gerrit provider'
)
def remove_reaction(self, issue_comment_id: int, reaction_id: int): def remove_reaction(self, issue_comment_id: int, reaction_id: int):
raise NotImplementedError( raise NotImplementedError(
'Removing reactions is not implemented for the gerrit provider') 'Removing reactions is not implemented for the gerrit provider'
)
def get_commit_messages(self): def get_commit_messages(self):
return [self.repo.head.commit.message] return [self.repo.head.commit.message]
@ -235,20 +234,21 @@ class GerritProvider(GitProvider):
diffs = self.repo.head.commit.diff( diffs = self.repo.head.commit.diff(
self.repo.head.commit.parents[0], # previous commit self.repo.head.commit.parents[0], # previous commit
create_patch=True, create_patch=True,
R=True R=True,
) )
diff_files = [] diff_files = []
for diff_item in diffs: for diff_item in diffs:
if diff_item.a_blob is not None: if diff_item.a_blob is not None:
original_file_content_str = ( original_file_content_str = diff_item.a_blob.data_stream.read().decode(
diff_item.a_blob.data_stream.read().decode('utf-8') 'utf-8'
) )
else: else:
original_file_content_str = "" # empty file original_file_content_str = "" # empty file
if diff_item.b_blob is not None: if diff_item.b_blob is not None:
new_file_content_str = diff_item.b_blob.data_stream.read(). \ new_file_content_str = diff_item.b_blob.data_stream.read().decode(
decode('utf-8') 'utf-8'
)
else: else:
new_file_content_str = "" # empty file new_file_content_str = "" # empty file
edit_type = EDIT_TYPE.MODIFIED edit_type = EDIT_TYPE.MODIFIED
@ -267,7 +267,7 @@ class GerritProvider(GitProvider):
edit_type=edit_type, edit_type=edit_type,
old_filename=None old_filename=None
if diff_item.a_path == diff_item.b_path if diff_item.a_path == diff_item.b_path
else diff_item.a_path else diff_item.a_path,
) )
) )
self.diff_files = diff_files self.diff_files = diff_files
@ -275,8 +275,7 @@ class GerritProvider(GitProvider):
def get_files(self): def get_files(self):
diff_index = self.repo.head.commit.diff( diff_index = self.repo.head.commit.diff(
self.repo.head.commit.parents[0], # previous commit self.repo.head.commit.parents[0], R=True # previous commit
R=True
) )
# Get the list of changed files # Get the list of changed files
diff_files = [item.a_path for item in diff_index] diff_files = [item.a_path for item in diff_index]
@ -288,16 +287,22 @@ class GerritProvider(GitProvider):
prioritisation. prioritisation.
""" """
# Get all files in repository # Get all files in repository
filepaths = [Path(item.path) for item in filepaths = [
self.repo.tree().traverse() if item.type == 'blob'] Path(item.path)
for item in self.repo.tree().traverse()
if item.type == 'blob'
]
# Identify language by file extension and count # Identify language by file extension and count
lang_count = Counter( lang_count = Counter(
ext.lstrip('.') for filepath in filepaths for ext in ext.lstrip('.')
[filepath.suffix.lower()]) for filepath in filepaths
for ext in [filepath.suffix.lower()]
)
# Convert counts to percentages # Convert counts to percentages
total_files = len(filepaths) total_files = len(filepaths)
lang_percentage = {lang: count / total_files * 100 for lang, count lang_percentage = {
in lang_count.items()} lang: count / total_files * 100 for lang, count in lang_count.items()
}
return lang_percentage return lang_percentage
def get_pr_description_full(self): def get_pr_description_full(self):
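The hunk above tallies repository languages by file extension with collections.Counter and converts the counts to percentages. A minimal standalone sketch of the same idea (the file list here is hypothetical, not from the repo):

    from collections import Counter
    from pathlib import Path

    # hypothetical file list standing in for repo.tree().traverse()
    filepaths = [Path("a.py"), Path("b.py"), Path("c.js"), Path("README.md")]

    # count files per extension (without the leading dot)
    lang_count = Counter(p.suffix.lower().lstrip('.') for p in filepaths)

    # convert counts to percentages of all files
    total = len(filepaths)
    lang_percentage = {lang: count / total * 100 for lang, count in lang_count.items()}
    print(lang_percentage)  # {'py': 50.0, 'js': 25.0, 'md': 25.0}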
@@ -312,7 +317,7 @@ class GerritProvider(GitProvider):
            'create_inline_comment',
            'publish_inline_comments',
            'get_labels',
-            'gfm_markdown'
+            'gfm_markdown',
        ]:
            return False
        return True

@@ -331,14 +336,9 @@ class GerritProvider(GitProvider):
            if is_code_context:
                context.append(line)
            else:
-                description.append(
-                    line.replace('*', '')
-                )
+                description.append(line.replace('*', ''))

-        return (
-            '\n'.join(description),
-            '\n'.join(context) + '\n' if context else ''
-        )
+        return ('\n'.join(description), '\n'.join(context) + '\n' if context else '')

    def publish_code_suggestions(self, code_suggestions: list):
        msg = []

@@ -372,15 +372,19 @@ class GerritProvider(GitProvider):
    def publish_inline_comments(self, comments: list[dict]):
        raise NotImplementedError(
-            'Publishing inline comments is not implemented for the gerrit '
-            'provider')
+            'Publishing inline comments is not implemented for the gerrit ' 'provider'
+        )

-    def publish_inline_comment(self, body: str, relevant_file: str,
-                               relevant_line_in_file: str, original_suggestion=None):
+    def publish_inline_comment(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_line_in_file: str,
+        original_suggestion=None,
+    ):
        raise NotImplementedError(
-            'Publishing inline comments is not implemented for the gerrit '
-            'provider')
+            'Publishing inline comments is not implemented for the gerrit ' 'provider'
+        )

    def publish_labels(self, labels):
        # Not applicable to the local git provider,
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
from typing import Optional

@@ -9,6 +10,7 @@ from utils.pr_agent.log import get_logger
MAX_FILES_ALLOWED_FULL = 50

class GitProvider(ABC):
    @abstractmethod
    def is_supported(self, capability: str) -> bool:

@@ -61,11 +63,18 @@ class GitProvider(ABC):
    def reply_to_comment_from_comment_id(self, comment_id: int, body: str):
        pass

-    def get_pr_description(self, full: bool = True, split_changes_walkthrough=False) -> str or tuple:
+    def get_pr_description(
+        self, full: bool = True, split_changes_walkthrough=False
+    ) -> str or tuple:
        from utils.pr_agent.algo.utils import clip_tokens
        from utils.pr_agent.config_loader import get_settings
-        max_tokens_description = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
-        description = self.get_pr_description_full() if full else self.get_user_description()
+
+        max_tokens_description = get_settings().get(
+            "CONFIG.MAX_DESCRIPTION_TOKENS", None
+        )
+        description = (
+            self.get_pr_description_full() if full else self.get_user_description()
+        )
        if split_changes_walkthrough:
            description, files = process_description(description)
        if max_tokens_description:

@@ -94,7 +103,9 @@ class GitProvider(ABC):
        # return nothing (empty string) because it means there is no user description
        user_description_header = "### **user description**"
        if user_description_header not in description_lowercase:
-            get_logger().info(f"Existing description was generated by the pr-agent, but it doesn't contain a user description")
+            get_logger().info(
+                f"Existing description was generated by the pr-agent, but it doesn't contain a user description"
+            )
            return ""

        # otherwise, extract the original user description from the existing pr-agent description and return it

@@ -103,7 +114,9 @@ class GitProvider(ABC):
        # the 'user description' is in the beginning. extract and return it
        possible_headers = self._possible_headers()
-        start_position = description_lowercase.find(user_description_header) + len(user_description_header)
+        start_position = description_lowercase.find(user_description_header) + len(
+            user_description_header
+        )
        end_position = len(description)
        for header in possible_headers:  # try to clip at the next header
            if header != user_description_header and header in description_lowercase:

@@ -115,20 +128,34 @@ class GitProvider(ABC):
        else:
            original_user_description = description.split("___")[0].strip()
            if original_user_description.lower().startswith(user_description_header):
-                original_user_description = original_user_description[len(user_description_header):].strip()
+                original_user_description = original_user_description[
+                    len(user_description_header) :
+                ].strip()

-        get_logger().info(f"Extracted user description from existing description",
-                          description=original_user_description)
+        get_logger().info(
+            f"Extracted user description from existing description",
+            description=original_user_description,
+        )
        self.user_description = original_user_description
        return original_user_description

    def _possible_headers(self):
-        return ("### **user description**", "### **pr type**", "### **pr description**", "### **pr labels**", "### **type**", "### **description**",
-                "### **labels**", "### 🤖 generated by pr agent")
+        return (
+            "### **user description**",
+            "### **pr type**",
+            "### **pr description**",
+            "### **pr labels**",
+            "### **type**",
+            "### **description**",
+            "### **labels**",
+            "### 🤖 generated by pr agent",
+        )

    def _is_generated_by_pr_agent(self, description_lowercase: str) -> bool:
        possible_headers = self._possible_headers()
-        return any(description_lowercase.startswith(header) for header in possible_headers)
+        return any(
+            description_lowercase.startswith(header) for header in possible_headers
+        )

    @abstractmethod
    def get_repo_settings(self):
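The extraction logic above locates the user-description header and then clips at the earliest of the other known headers. A small sketch of that clipping step in isolation (the sample description text is made up):

    description = (
        "### **User description**\nFix the login bug.\n\n"
        "### **PR type**\nBug fix\n"
    )
    description_lower = description.lower()

    header = "### **user description**"
    possible_headers = ("### **pr type**", "### **pr description**")

    # start right after the user-description header
    start = description_lower.find(header) + len(header)
    # clip at the earliest following known header, if any
    end = len(description)
    for other in possible_headers:
        pos = description_lower.find(other)
        if pos != -1:
            end = min(end, pos)

    print(description[start:end].strip())  # "Fix the login bug."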
@@ -140,10 +167,17 @@ class GitProvider(ABC):
    def get_pr_id(self):
        return ""

-    def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
+    def get_line_link(
+        self,
+        relevant_file: str,
+        relevant_line_start: int,
+        relevant_line_end: int = None,
+    ) -> str:
        return ""

-    def get_lines_link_original_file(self, filepath:str, component_range: Range) -> str:
+    def get_lines_link_original_file(
+        self, filepath: str, component_range: Range
+    ) -> str:
        return ""

    #### comments operations ####

@@ -151,18 +185,24 @@ class GitProvider(ABC):
    def publish_comment(self, pr_comment: str, is_temporary: bool = False):
        pass

-    def publish_persistent_comment(self, pr_comment: str,
-                                   initial_header: str,
-                                   update_header: bool = True,
-                                   name='review',
-                                   final_update_message=True):
+    def publish_persistent_comment(
+        self,
+        pr_comment: str,
+        initial_header: str,
+        update_header: bool = True,
+        name='review',
+        final_update_message=True,
+    ):
        self.publish_comment(pr_comment)

-    def publish_persistent_comment_full(self, pr_comment: str,
-                                        initial_header: str,
-                                        update_header: bool = True,
-                                        name='review',
-                                        final_update_message=True):
+    def publish_persistent_comment_full(
+        self,
+        pr_comment: str,
+        initial_header: str,
+        update_header: bool = True,
+        name='review',
+        final_update_message=True,
+    ):
        try:
            prev_comments = list(self.get_issue_comments())
            for comment in prev_comments:

@@ -171,29 +211,46 @@ class GitProvider(ABC):
                    comment_url = self.get_comment_url(comment)
                    if update_header:
                        updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n"
-                        pr_comment_updated = pr_comment.replace(initial_header, updated_header)
+                        pr_comment_updated = pr_comment.replace(
+                            initial_header, updated_header
+                        )
                    else:
                        pr_comment_updated = pr_comment
-                    get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message")
+                    get_logger().info(
+                        f"Persistent mode - updating comment {comment_url} to latest {name} message"
+                    )
                    # response = self.mr.notes.update(comment.id, {'body': pr_comment_updated})
                    self.edit_comment(comment, pr_comment_updated)
                    if final_update_message:
                        self.publish_comment(
-                            f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}")
+                            f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}"
+                        )
                    return
        except Exception as e:
            get_logger().exception(f"Failed to update persistent review, error: {e}")
            pass
        self.publish_comment(pr_comment)

    @abstractmethod
-    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
+    def publish_inline_comment(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_line_in_file: str,
+        original_suggestion=None,
+    ):
        pass

-    def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
-                              absolute_position: int = None):
-        raise NotImplementedError("This git provider does not support creating inline comments yet")
+    def create_inline_comment(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_line_in_file: str,
+        absolute_position: int = None,
+    ):
+        raise NotImplementedError(
+            "This git provider does not support creating inline comments yet"
+        )

    @abstractmethod
    def publish_inline_comments(self, comments: list[dict]):
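publish_persistent_comment_full implements an update-in-place pattern: find a previous bot comment that starts with a known header, rewrite it, and only fall back to posting a new comment when no match is found. A rough sketch of that control flow against a hypothetical provider object (get_issue_comments, edit_comment, and publish_comment mirror the GitProvider interface shown above):

    def upsert_comment(provider, body, initial_header="## PR Review"):
        """Update an existing header-matched comment, else post a new one."""
        try:
            for comment in provider.get_issue_comments():
                if comment.body.startswith(initial_header):
                    provider.edit_comment(comment, body)  # update in place
                    return
        except Exception:
            pass  # on any provider error, fall through to a fresh comment
        provider.publish_comment(body)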
@@ -227,7 +284,9 @@ class GitProvider(ABC):
        pass

    @abstractmethod
-    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
+    def add_eyes_reaction(
+        self, issue_comment_id: int, disable_eyes: bool = False
+    ) -> Optional[int]:
        pass

    @abstractmethod

@@ -284,16 +343,23 @@ def get_main_pr_language(languages, files) -> str:
            if not file:
                continue
            if isinstance(file, str):
-                file = FilePatchInfo(base_file=None, head_file=None, patch=None, filename=file)
+                file = FilePatchInfo(
+                    base_file=None, head_file=None, patch=None, filename=file
+                )
            extension_list.append(file.filename.rsplit('.')[-1])

        # get the most common extension
        most_common_extension = '.' + max(set(extension_list), key=extension_list.count)
        try:
            language_extension_map_org = get_settings().language_extension_map_org
-            language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
+            language_extension_map = {
+                k.lower(): v for k, v in language_extension_map_org.items()
+            }

-            if top_language in language_extension_map and most_common_extension in language_extension_map[top_language]:
+            if (
+                top_language in language_extension_map
+                and most_common_extension in language_extension_map[top_language]
+            ):
                main_language_str = top_language
            else:
                for language, extensions in language_extension_map.items():

@@ -332,8 +398,6 @@ def get_main_pr_language(languages, files) -> str:
    return main_language_str

class IncrementalPR:
    def __init__(self, is_incremental: bool = False):
        self.is_incremental = is_incremental
@@ -18,14 +18,23 @@ from ..algo.file_filter import filter_ignored
from ..algo.git_patch_processing import extract_hunk_headers
from ..algo.language_handler import is_valid_file
from ..algo.types import EDIT_TYPE
-from ..algo.utils import (PRReviewHeader, Range, clip_tokens,
-                          find_line_number_of_relevant_line_in_file,
-                          load_large_diff, set_file_languages)
+from ..algo.utils import (
+    PRReviewHeader,
+    Range,
+    clip_tokens,
+    find_line_number_of_relevant_line_in_file,
+    load_large_diff,
+    set_file_languages,
+)
from ..config_loader import get_settings
from ..log import get_logger
from ..servers.utils import RateLimitExceeded
-from .git_provider import (MAX_FILES_ALLOWED_FULL, FilePatchInfo, GitProvider,
-                           IncrementalPR)
+from .git_provider import (
+    MAX_FILES_ALLOWED_FULL,
+    FilePatchInfo,
+    GitProvider,
+    IncrementalPR,
+)

class GithubProvider(GitProvider):

@@ -36,8 +45,14 @@ class GithubProvider(GitProvider):
        except Exception:
            self.installation_id = None
        self.max_comment_chars = 65000
-        self.base_url = get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/")  # "https://api.github.com"
-        self.base_url_html = self.base_url.split("api/")[0].rstrip("/") if "api/" in self.base_url else "https://github.com"
+        self.base_url = (
+            get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/")
+        )  # "https://api.github.com"
+        self.base_url_html = (
+            self.base_url.split("api/")[0].rstrip("/")
+            if "api/" in self.base_url
+            else "https://github.com"
+        )
        self.github_client = self._get_github_client()
        self.repo = None
        self.pr_num = None

@@ -50,7 +65,9 @@ class GithubProvider(GitProvider):
            self.set_pr(pr_url)
            self.pr_commits = list(self.pr.get_commits())
            self.last_commit_id = self.pr_commits[-1]
-            self.pr_url = self.get_pr_url()  # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object
+            self.pr_url = (
+                self.get_pr_url()
+            )  # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object
        else:
            self.pr_commits = None

@@ -80,10 +97,14 @@ class GithubProvider(GitProvider):
            # Get all files changed during the commit range
            for commit in self.incremental.commits_range:
-                if commit.commit.message.startswith(f"Merge branch '{self._get_repo().default_branch}'"):
+                if commit.commit.message.startswith(
+                    f"Merge branch '{self._get_repo().default_branch}'"
+                ):
                    get_logger().info(f"Skipping merge commit {commit.commit.message}")
                    continue
-                self.unreviewed_files_set.update({file.filename: file for file in commit.files})
+                self.unreviewed_files_set.update(
+                    {file.filename: file for file in commit.files}
+                )
        else:
            get_logger().info("No previous review found, will review the entire PR")
            self.incremental.is_incremental = False

@@ -98,7 +119,11 @@ class GithubProvider(GitProvider):
            else:
                self.incremental.last_seen_commit = self.pr_commits[index]
                break
-        return self.pr_commits[first_new_commit_index:] if first_new_commit_index is not None else []
+        return (
+            self.pr_commits[first_new_commit_index:]
+            if first_new_commit_index is not None
+            else []
+        )

    def get_previous_review(self, *, full: bool, incremental: bool):
        if not (full or incremental):

@@ -138,8 +163,13 @@ class GithubProvider(GitProvider):
        except Exception as e:
            return -1

-    @retry(exceptions=RateLimitExceeded,
-           tries=get_settings().github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3))
+    @retry(
+        exceptions=RateLimitExceeded,
+        tries=get_settings().github.ratelimit_retries,
+        delay=2,
+        backoff=2,
+        jitter=(1, 3),
+    )
    def get_diff_files(self) -> list[FilePatchInfo]:
        """
        Retrieves the list of files that have been modified, added, deleted, or renamed in a pull request in GitHub,
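The @retry decorator above (from the `retry` package) re-runs get_diff_files with exponential backoff and random jitter whenever a RateLimitExceeded is raised. A minimal sketch of the same pattern, with a made-up flaky function:

    import random

    from retry import retry


    class RateLimitExceeded(Exception):
        """Stand-in for the project's servers.utils.RateLimitExceeded."""


    # retry up to 5 times; wait 2s, then 4s, 8s, ... plus 1-3s of random jitter
    @retry(exceptions=RateLimitExceeded, tries=5, delay=2, backoff=2, jitter=(1, 3))
    def fetch_diff():
        if random.random() < 0.5:  # simulate an intermittent API rate limit
            raise RateLimitExceeded("Rate limit exceeded for GitHub API.")
        return ["file1.py", "file2.py"]


    print(fetch_diff())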
@@ -167,9 +197,10 @@ class GithubProvider(GitProvider):
            try:
                names_original = [file.filename for file in files_original]
                names_new = [file.filename for file in files]
-                get_logger().info(f"Filtered out [ignore] files for pull request:", extra=
-                                  {"files": names_original,
-                                   "filtered_files": names_new})
+                get_logger().info(
+                    f"Filtered out [ignore] files for pull request:",
+                    extra={"files": names_original, "filtered_files": names_new},
+                )
            except Exception:
                pass

@@ -184,14 +215,17 @@ class GithubProvider(GitProvider):
            repo = self.repo_obj
            pr = self.pr
            try:
-                compare = repo.compare(pr.base.sha, pr.head.sha)  # communication with GitHub
+                compare = repo.compare(
+                    pr.base.sha, pr.head.sha
+                )  # communication with GitHub
                merge_base_commit = compare.merge_base_commit
            except Exception as e:
                get_logger().error(f"Failed to get merge base commit: {e}")
                merge_base_commit = pr.base
            if merge_base_commit.sha != pr.base.sha:
                get_logger().info(
-                    f"Using merge base commit {merge_base_commit.sha} instead of base commit ")
+                    f"Using merge base commit {merge_base_commit.sha} instead of base commit "
+                )

            counter_valid = 0
            for file in files:

@@ -207,29 +241,48 @@ class GithubProvider(GitProvider):
                # allow only a limited number of files to be fully loaded. We can manage the rest with diffs only
                counter_valid += 1
                avoid_load = False
-                if counter_valid >= MAX_FILES_ALLOWED_FULL and patch and not self.incremental.is_incremental:
+                if (
+                    counter_valid >= MAX_FILES_ALLOWED_FULL
+                    and patch
+                    and not self.incremental.is_incremental
+                ):
                    avoid_load = True
                    if counter_valid == MAX_FILES_ALLOWED_FULL:
-                        get_logger().info(f"Too many files in PR, will avoid loading full content for rest of files")
+                        get_logger().info(
+                            f"Too many files in PR, will avoid loading full content for rest of files"
+                        )

                if avoid_load:
                    new_file_content_str = ""
                else:
-                    new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha)  # communication with GitHub
+                    new_file_content_str = self._get_pr_file_content(
+                        file, self.pr.head.sha
+                    )  # communication with GitHub

                if self.incremental.is_incremental and self.unreviewed_files_set:
-                    original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha)
-                    patch = load_large_diff(file.filename, new_file_content_str, original_file_content_str)
+                    original_file_content_str = self._get_pr_file_content(
+                        file, self.incremental.last_seen_commit_sha
+                    )
+                    patch = load_large_diff(
+                        file.filename,
+                        new_file_content_str,
+                        original_file_content_str,
+                    )
                    self.unreviewed_files_set[file.filename] = patch
                else:
                    if avoid_load:
                        original_file_content_str = ""
                    else:
-                        original_file_content_str = self._get_pr_file_content(file, merge_base_commit.sha)
+                        original_file_content_str = self._get_pr_file_content(
+                            file, merge_base_commit.sha
+                        )
                        # original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
                    if not patch:
-                        patch = load_large_diff(file.filename, new_file_content_str, original_file_content_str)
+                        patch = load_large_diff(
+                            file.filename,
+                            new_file_content_str,
+                            original_file_content_str,
+                        )

                if file.status == 'added':
                    edit_type = EDIT_TYPE.ADDED

@@ -249,16 +302,27 @@ class GithubProvider(GitProvider):
                    num_minus_lines = file.deletions
                else:
                    patch_lines = patch.splitlines(keepends=True)
-                    num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
-                    num_minus_lines = len([line for line in patch_lines if line.startswith('-')])
+                    num_plus_lines = len(
+                        [line for line in patch_lines if line.startswith('+')]
+                    )
+                    num_minus_lines = len(
+                        [line for line in patch_lines if line.startswith('-')]
+                    )

-                file_patch_canonical_structure = FilePatchInfo(original_file_content_str, new_file_content_str, patch,
-                                                               file.filename, edit_type=edit_type,
-                                                               num_plus_lines=num_plus_lines,
-                                                               num_minus_lines=num_minus_lines,)
+                file_patch_canonical_structure = FilePatchInfo(
+                    original_file_content_str,
+                    new_file_content_str,
+                    patch,
+                    file.filename,
+                    edit_type=edit_type,
+                    num_plus_lines=num_plus_lines,
+                    num_minus_lines=num_minus_lines,
+                )
                diff_files.append(file_patch_canonical_structure)

            if invalid_files_names:
-                get_logger().info(f"Filtered out files with invalid extensions: {invalid_files_names}")
+                get_logger().info(
+                    f"Filtered out files with invalid extensions: {invalid_files_names}"
+                )

            self.diff_files = diff_files
            try:

@@ -269,8 +333,10 @@ class GithubProvider(GitProvider):
            return diff_files

        except Exception as e:
-            get_logger().error(f"Failing to get diff files: {e}",
-                               artifact={"traceback": traceback.format_exc()})
+            get_logger().error(
+                f"Failing to get diff files: {e}",
+                artifact={"traceback": traceback.format_exc()},
+            )
            raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e

    def publish_description(self, pr_title: str, pr_body: str):

@@ -282,16 +348,23 @@ class GithubProvider(GitProvider):
    def get_comment_url(self, comment) -> str:
        return comment.html_url

-    def publish_persistent_comment(self, pr_comment: str,
-                                   initial_header: str,
-                                   update_header: bool = True,
-                                   name='review',
-                                   final_update_message=True):
-        self.publish_persistent_comment_full(pr_comment, initial_header, update_header, name, final_update_message)
+    def publish_persistent_comment(
+        self,
+        pr_comment: str,
+        initial_header: str,
+        update_header: bool = True,
+        name='review',
+        final_update_message=True,
+    ):
+        self.publish_persistent_comment_full(
+            pr_comment, initial_header, update_header, name, final_update_message
+        )

    def publish_comment(self, pr_comment: str, is_temporary: bool = False):
        if is_temporary and not get_settings().config.publish_output_progress:
-            get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
+            get_logger().debug(
+                f"Skipping publish_comment for temporary comment: {pr_comment}"
+            )
            return None
        pr_comment = self.limit_output_characters(pr_comment, self.max_comment_chars)
        response = self.pr.create_issue_comment(pr_comment)

@@ -303,34 +376,58 @@ class GithubProvider(GitProvider):
            self.pr.comments_list.append(response)
        return response

-    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
+    def publish_inline_comment(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_line_in_file: str,
+        original_suggestion=None,
+    ):
        body = self.limit_output_characters(body, self.max_comment_chars)
-        self.publish_inline_comments([self.create_inline_comment(body, relevant_file, relevant_line_in_file)])
+        self.publish_inline_comments(
+            [self.create_inline_comment(body, relevant_file, relevant_line_in_file)]
+        )

-    def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
-                              absolute_position: int = None):
+    def create_inline_comment(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_line_in_file: str,
+        absolute_position: int = None,
+    ):
        body = self.limit_output_characters(body, self.max_comment_chars)
-        position, absolute_position = find_line_number_of_relevant_line_in_file(self.diff_files,
-                                                                                relevant_file.strip('`'),
-                                                                                relevant_line_in_file,
-                                                                                absolute_position)
+        position, absolute_position = find_line_number_of_relevant_line_in_file(
+            self.diff_files,
+            relevant_file.strip('`'),
+            relevant_line_in_file,
+            absolute_position,
+        )
        if position == -1:
-            get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
+            get_logger().info(
+                f"Could not find position for {relevant_file} {relevant_line_in_file}"
+            )
            subject_type = "FILE"
        else:
            subject_type = "LINE"
        path = relevant_file.strip()
-        return dict(body=body, path=path, position=position) if subject_type == "LINE" else {}
+        return (
+            dict(body=body, path=path, position=position)
+            if subject_type == "LINE"
+            else {}
+        )

-    def publish_inline_comments(self, comments: list[dict], disable_fallback: bool = False):
+    def publish_inline_comments(
+        self, comments: list[dict], disable_fallback: bool = False
+    ):
        try:
            # publish all comments in a single message
            self.pr.create_review(commit=self.last_commit_id, comments=comments)
        except Exception as e:
-            get_logger().info(f"Initially failed to publish inline comments as committable")
+            get_logger().info(
+                f"Initially failed to publish inline comments as committable"
+            )

-            if (getattr(e, "status", None) == 422 and not disable_fallback):
+            if getattr(e, "status", None) == 422 and not disable_fallback:
                pass  # continue to try _publish_inline_comments_fallback_with_verification
            else:
                raise e  # will end up with publishing the comments one by one
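The publish path above tries one batched review first and only falls back to per-comment publishing when GitHub rejects the batch with a 422 validation error. The same try-batch-then-degrade pattern in isolation (publish_batch and publish_one are hypothetical helpers, not part of the codebase):

    def publish_with_fallback(comments, publish_batch, publish_one):
        """Try a single batched publish; on a validation error (HTTP 422),
        degrade to publishing comments one by one and report failures."""
        try:
            publish_batch(comments)
            return []
        except Exception as e:
            if getattr(e, "status", None) != 422:
                raise  # not a validation problem - let the caller see it
        failed = []
        for comment in comments:
            try:
                publish_one(comment)
            except Exception:
                failed.append(comment)  # keep going; collect what could not be posted
        return failed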
@@ -338,7 +435,9 @@ class GithubProvider(GitProvider):
        try:
            self._publish_inline_comments_fallback_with_verification(comments)
        except Exception as e:
-            get_logger().error(f"Failed to publish inline code comments fallback, error: {e}")
+            get_logger().error(
+                f"Failed to publish inline code comments fallback, error: {e}"
+            )
            raise e

    def _publish_inline_comments_fallback_with_verification(self, comments: list[dict]):

@@ -352,20 +451,27 @@ class GithubProvider(GitProvider):
        # publish as a group the verified comments
        if verified_comments:
            try:
-                self.pr.create_review(commit=self.last_commit_id, comments=verified_comments)
+                self.pr.create_review(
+                    commit=self.last_commit_id, comments=verified_comments
+                )
            except:
                pass

        # try to publish one by one the invalid comments as a one-line code comment
        if invalid_comments and get_settings().github.try_fix_invalid_inline_comments:
            fixed_comments_as_one_liner = self._try_fix_invalid_inline_comments(
-                [comment for comment, _ in invalid_comments])
+                [comment for comment, _ in invalid_comments]
+            )
            for comment in fixed_comments_as_one_liner:
                try:
                    self.publish_inline_comments([comment], disable_fallback=True)
-                    get_logger().info(f"Published invalid comment as a single line comment: {comment}")
+                    get_logger().info(
+                        f"Published invalid comment as a single line comment: {comment}"
+                    )
                except:
-                    get_logger().error(f"Failed to publish invalid comment as a single line comment: {comment}")
+                    get_logger().error(
+                        f"Failed to publish invalid comment as a single line comment: {comment}"
+                    )

    def _verify_code_comment(self, comment: dict):
        is_verified = False

@@ -374,7 +480,8 @@ class GithubProvider(GitProvider):
            # event ="" # By leaving this blank, you set the review action state to PENDING
            input = dict(commit_id=self.last_commit_id.sha, comments=[comment])
            headers, data = self.pr._requester.requestJsonAndCheck(
-                "POST", f"{self.pr.url}/reviews", input=input)
+                "POST", f"{self.pr.url}/reviews", input=input
+            )
            pending_review_id = data["id"]
            is_verified = True
        except Exception as err:

@@ -383,12 +490,16 @@ class GithubProvider(GitProvider):
            e = err
        if pending_review_id is not None:
            try:
-                self.pr._requester.requestJsonAndCheck("DELETE", f"{self.pr.url}/reviews/{pending_review_id}")
+                self.pr._requester.requestJsonAndCheck(
+                    "DELETE", f"{self.pr.url}/reviews/{pending_review_id}"
+                )
            except Exception:
                pass
        return is_verified, e

-    def _verify_code_comments(self, comments: list[dict]) -> tuple[list[dict], list[tuple[dict, Exception]]]:
+    def _verify_code_comments(
+        self, comments: list[dict]
+    ) -> tuple[list[dict], list[tuple[dict, Exception]]]:
        """Very each comment against the GitHub API and return 2 lists: 1 of verified and 1 of invalid comments"""
        verified_comments = []
        invalid_comments = []

@@ -401,17 +512,22 @@ class GithubProvider(GitProvider):
                invalid_comments.append((comment, e))
        return verified_comments, invalid_comments

-    def _try_fix_invalid_inline_comments(self, invalid_comments: list[dict]) -> list[dict]:
+    def _try_fix_invalid_inline_comments(
+        self, invalid_comments: list[dict]
+    ) -> list[dict]:
        """
        Try fixing invalid comments by removing the suggestion part and setting the comment just on the first line.
        Return only comments that have been modified in some way.
        This is a best-effort attempt to fix invalid comments, and should be verified accordingly.
        """
        import copy

        fixed_comments = []
        for comment in invalid_comments:
            try:
-                fixed_comment = copy.deepcopy(comment)  # avoid modifying the original comment dict for later logging
+                fixed_comment = copy.deepcopy(
+                    comment
+                )  # avoid modifying the original comment dict for later logging
                if "```suggestion" in comment["body"]:
                    fixed_comment["body"] = comment["body"].split("```suggestion")[0]
                if "start_line" in comment:
@@ -432,7 +548,9 @@ class GithubProvider(GitProvider):
        """
        post_parameters_list = []

-        code_suggestions_validated = self.validate_comments_inside_hunks(code_suggestions)
+        code_suggestions_validated = self.validate_comments_inside_hunks(
+            code_suggestions
+        )

        for suggestion in code_suggestions_validated:
            body = suggestion['body']

@@ -442,13 +560,16 @@ class GithubProvider(GitProvider):
            if not relevant_lines_start or relevant_lines_start == -1:
                get_logger().exception(
-                    f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
+                    f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}"
+                )
                continue

            if relevant_lines_end < relevant_lines_start:
-                get_logger().exception(f"Failed to publish code suggestion, "
-                                       f"relevant_lines_end is {relevant_lines_end} and "
-                                       f"relevant_lines_start is {relevant_lines_start}")
+                get_logger().exception(
+                    f"Failed to publish code suggestion, "
+                    f"relevant_lines_end is {relevant_lines_end} and "
+                    f"relevant_lines_start is {relevant_lines_start}"
+                )
                continue

            if relevant_lines_end > relevant_lines_start:

@@ -484,17 +605,21 @@ class GithubProvider(GitProvider):
                # Log as warning for permission-related issues (usually due to polling)
                get_logger().warning(
                    "Failed to edit github comment due to permission restrictions",
-                    artifact={"error": e})
+                    artifact={"error": e},
+                )
            else:
-                get_logger().exception(f"Failed to edit github comment", artifact={"error": e})
+                get_logger().exception(
+                    f"Failed to edit github comment", artifact={"error": e}
+                )

    def edit_comment_from_comment_id(self, comment_id: int, body: str):
        try:
            # self.pr.get_issue_comment(comment_id).edit(body)
            body = self.limit_output_characters(body, self.max_comment_chars)
            headers, data_patch = self.pr._requester.requestJsonAndCheck(
-                "PATCH", f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}",
-                input={"body": body}
+                "PATCH",
+                f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}",
+                input={"body": body},
            )
        except Exception as e:
            get_logger().exception(f"Failed to edit comment, error: {e}")

@@ -504,8 +629,9 @@ class GithubProvider(GitProvider):
            # self.pr.get_issue_comment(comment_id).edit(body)
            body = self.limit_output_characters(body, self.max_comment_chars)
            headers, data_patch = self.pr._requester.requestJsonAndCheck(
-                "POST", f"{self.base_url}/repos/{self.repo}/pulls/{self.pr_num}/comments/{comment_id}/replies",
-                input={"body": body}
+                "POST",
+                f"{self.base_url}/repos/{self.repo}/pulls/{self.pr_num}/comments/{comment_id}/replies",
+                input={"body": body},
            )
        except Exception as e:
            get_logger().exception(f"Failed to reply comment, error: {e}")

@@ -528,7 +654,9 @@ class GithubProvider(GitProvider):
            )
            for comment in file_comments:
                comment['commit_id'] = self.last_commit_id.sha
-                comment['body'] = self.limit_output_characters(comment['body'], self.max_comment_chars)
+                comment['body'] = self.limit_output_characters(
+                    comment['body'], self.max_comment_chars
+                )

                found = False
                for existing_comment in existing_comments:

@@ -536,13 +664,23 @@ class GithubProvider(GitProvider):
                    our_app_name = get_settings().get("GITHUB.APP_NAME", "")
                    same_comment_creator = False
                    if self.deployment_type == 'app':
-                        same_comment_creator = our_app_name.lower() in existing_comment['user']['login'].lower()
+                        same_comment_creator = (
+                            our_app_name.lower()
+                            in existing_comment['user']['login'].lower()
+                        )
                    elif self.deployment_type == 'user':
-                        same_comment_creator = self.github_user_id == existing_comment['user']['login']
-                    if existing_comment['subject_type'] == 'file' and comment['path'] == existing_comment['path'] and same_comment_creator:
+                        same_comment_creator = (
+                            self.github_user_id == existing_comment['user']['login']
+                        )
+                    if (
+                        existing_comment['subject_type'] == 'file'
+                        and comment['path'] == existing_comment['path']
+                        and same_comment_creator
+                    ):
                        headers, data_patch = self.pr._requester.requestJsonAndCheck(
-                            "PATCH", f"{self.base_url}/repos/{self.repo}/pulls/comments/{existing_comment['id']}", input={"body":comment['body']}
+                            "PATCH",
+                            f"{self.base_url}/repos/{self.repo}/pulls/comments/{existing_comment['id']}",
+                            input={"body": comment['body']},
                        )
                        found = True
                        break

@@ -600,7 +738,9 @@ class GithubProvider(GitProvider):
        deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user")
        if deployment_type != 'user':
-            raise ValueError("Deployment mode must be set to 'user' to get notifications")
+            raise ValueError(
+                "Deployment mode must be set to 'user' to get notifications"
+            )

        notifications = self.github_client.get_user().get_notifications(since=since)
        return notifications

@@ -621,13 +761,16 @@ class GithubProvider(GitProvider):
    def get_workspace_name(self):
        return self.repo.split('/')[0]

-    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
+    def add_eyes_reaction(
+        self, issue_comment_id: int, disable_eyes: bool = False
+    ) -> Optional[int]:
        if disable_eyes:
            return None
        try:
            headers, data_patch = self.pr._requester.requestJsonAndCheck(
-                "POST", f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions",
-                input={"content": "eyes"}
+                "POST",
+                f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions",
+                input={"content": "eyes"},
            )
            return data_patch.get("id", None)
        except Exception as e:

@@ -639,7 +782,7 @@ class GithubProvider(GitProvider):
            # self.pr.get_issue_comment(issue_comment_id).delete_reaction(reaction_id)
            headers, data_patch = self.pr._requester.requestJsonAndCheck(
                "DELETE",
-                f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions/{reaction_id}"
+                f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions/{reaction_id}",
            )
            return True
        except Exception as e:

@@ -655,7 +798,9 @@ class GithubProvider(GitProvider):
        path_parts = parsed_url.path.strip('/').split('/')
        if 'api.github.com' in parsed_url.netloc or '/api/v3' in pr_url:
            if len(path_parts) < 5 or path_parts[3] != 'pulls':
-                raise ValueError("The provided URL does not appear to be a GitHub PR URL")
+                raise ValueError(
+                    "The provided URL does not appear to be a GitHub PR URL"
+                )
            repo_name = '/'.join(path_parts[1:3])
            try:
                pr_number = int(path_parts[4])

@@ -683,7 +828,9 @@ class GithubProvider(GitProvider):
        path_parts = parsed_url.path.strip('/').split('/')
        if 'api.github.com' in parsed_url.netloc:
            if len(path_parts) < 5 or path_parts[3] != 'issues':
-                raise ValueError("The provided URL does not appear to be a GitHub ISSUE URL")
+                raise ValueError(
+                    "The provided URL does not appear to be a GitHub ISSUE URL"
+                )
            repo_name = '/'.join(path_parts[1:3])
            try:
                issue_number = int(path_parts[4])

@@ -710,11 +857,18 @@ class GithubProvider(GitProvider):
                private_key = get_settings().github.private_key
                app_id = get_settings().github.app_id
            except AttributeError as e:
-                raise ValueError("GitHub app ID and private key are required when using GitHub app deployment") from e
+                raise ValueError(
+                    "GitHub app ID and private key are required when using GitHub app deployment"
+                ) from e
            if not self.installation_id:
-                raise ValueError("GitHub app installation ID is required when using GitHub app deployment")
-            auth = AppAuthentication(app_id=app_id, private_key=private_key,
-                                     installation_id=self.installation_id)
+                raise ValueError(
+                    "GitHub app installation ID is required when using GitHub app deployment"
+                )
+            auth = AppAuthentication(
+                app_id=app_id,
+                private_key=private_key,
+                installation_id=self.installation_id,
+            )
            return Github(app_auth=auth, base_url=self.base_url)

        if deployment_type == 'user':

@@ -723,19 +877,21 @@ class GithubProvider(GitProvider):
            except AttributeError as e:
                raise ValueError(
                    "GitHub token is required when using user deployment. See: "
-                    "https://github.com/Codium-ai/pr-agent#method-2-run-from-source") from e
+                    "https://github.com/Codium-ai/pr-agent#method-2-run-from-source"
+                ) from e
            return Github(auth=Auth.Token(token), base_url=self.base_url)

    def _get_repo(self):
-        if hasattr(self, 'repo_obj') and \
-                hasattr(self.repo_obj, 'full_name') and \
-                self.repo_obj.full_name == self.repo:
+        if (
+            hasattr(self, 'repo_obj')
+            and hasattr(self.repo_obj, 'full_name')
+            and self.repo_obj.full_name == self.repo
+        ):
            return self.repo_obj
        else:
            self.repo_obj = self.github_client.get_repo(self.repo)
            return self.repo_obj

    def _get_pr(self):
        return self._get_repo().get_pull(self.pr_num)
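_get_github_client above selects between GitHub App credentials and a personal access token based on GITHUB.DEPLOYMENT_TYPE. A condensed sketch of the same branching with PyGithub, mirroring the imports used in the diff; all credential values are placeholders:

    from github import AppAuthentication, Auth, Github

    def make_client(deployment_type: str, base_url: str = "https://api.github.com") -> Github:
        if deployment_type == 'app':
            # GitHub App flow: authenticate as an installation of the app
            auth = AppAuthentication(
                app_id="<app-id>",             # placeholder
                private_key="<pem-contents>",  # placeholder
                installation_id=12345,         # placeholder
            )
            return Github(app_auth=auth, base_url=base_url)
        # user flow: a plain personal access token
        return Github(auth=Auth.Token("<token>"), base_url=base_url)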
@@ -771,9 +927,14 @@ class GithubProvider(GitProvider):
    def publish_labels(self, pr_types):
        try:
-            label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5",
-                               "Enhancement": "bfd4f2", "Documentation": "d4c5f9",
-                               "Other": "d1bcf9"}
+            label_color_map = {
+                "Bug fix": "1d76db",
+                "Tests": "e99695",
+                "Bug fix with tests": "c5def5",
+                "Enhancement": "bfd4f2",
+                "Documentation": "d4c5f9",
+                "Other": "d1bcf9",
+            }
            post_parameters = []
            for p in pr_types:
                color = label_color_map.get(p, "d1bcf9")  # default to "Other" color

@@ -791,7 +952,8 @@ class GithubProvider(GitProvider):
                return [label.name for label in labels]
            else:  # obtain the latest labels. Maybe they changed while the AI was running
                headers, labels = self.pr._requester.requestJsonAndCheck(
-                    "GET", f"{self.pr.issue_url}/labels")
+                    "GET", f"{self.pr.issue_url}/labels"
+                )
                return [label['name'] for label in labels]

        except Exception as e:

@@ -813,7 +975,9 @@ class GithubProvider(GitProvider):
        try:
            commit_list = self.pr.get_commits()
            commit_messages = [commit.commit.message for commit in commit_list]
-            commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages)])
+            commit_messages_str = "\n".join(
+                [f"{i + 1}. {message}" for i, message in enumerate(commit_messages)]
+            )
        except Exception:
            commit_messages_str = ""
        if max_tokens:

@@ -822,13 +986,16 @@ class GithubProvider(GitProvider):
    def generate_link_to_relevant_line_number(self, suggestion) -> str:
        try:
-            relevant_file = suggestion['relevant_file'].strip('`').strip("'").strip('\n')
+            relevant_file = (
+                suggestion['relevant_file'].strip('`').strip("'").strip('\n')
+            )
            relevant_line_str = suggestion['relevant_line'].strip('\n')
            if not relevant_line_str:
                return ""

-            position, absolute_position = find_line_number_of_relevant_line_in_file \
-                (self.diff_files, relevant_file, relevant_line_str)
+            position, absolute_position = find_line_number_of_relevant_line_in_file(
+                self.diff_files, relevant_file, relevant_line_str
+            )

            if absolute_position != -1:
                # # link to right file only

@@ -844,7 +1011,12 @@ class GithubProvider(GitProvider):
        return ""

-    def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
+    def get_line_link(
+        self,
+        relevant_file: str,
+        relevant_line_start: int,
+        relevant_line_end: int = None,
+    ) -> str:
        sha_file = hashlib.sha256(relevant_file.encode('utf-8')).hexdigest()
        if relevant_line_start == -1:
            link = f"{self.base_url_html}/{self.repo}/pull/{self.pr_num}/files#diff-{sha_file}"

@@ -854,7 +1026,9 @@ class GithubProvider(GitProvider):
            link = f"{self.base_url_html}/{self.repo}/pull/{self.pr_num}/files#diff-{sha_file}R{relevant_line_start}"
        return link

-    def get_lines_link_original_file(self, filepath: str, component_range: Range) -> str:
+    def get_lines_link_original_file(
+        self, filepath: str, component_range: Range
+    ) -> str:
        """
        Returns the link to the original file on GitHub that corresponds to the given filepath and component range.

@@ -876,8 +1050,10 @@ class GithubProvider(GitProvider):
        line_end = component_range.line_end + 1
        # link = (f"https://github.com/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/"
        #         f"#L{line_start}-L{line_end}")
-        link = (f"{self.base_url_html}/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/"
-                f"#L{line_start}-L{line_end}")
+        link = (
+            f"{self.base_url_html}/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/"
+            f"#L{line_start}-L{line_end}"
+        )

        return link
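get_line_link above builds GitHub's anchors for the PR "Files changed" view: the fragment is diff- plus the SHA-256 hex digest of the file path, optionally suffixed with R<line> for a line on the right (new) side. A standalone sketch (the repo name and PR number are made up):

    import hashlib

    def line_link(repo: str, pr_num: int, path: str, line: int = -1) -> str:
        # GitHub anchors diffs by the sha256 hex digest of the file path
        sha_file = hashlib.sha256(path.encode('utf-8')).hexdigest()
        link = f"https://github.com/{repo}/pull/{pr_num}/files#diff-{sha_file}"
        if line != -1:
            link += f"R{line}"  # R = right (new) side of the diff
        return link

    print(line_link("octocat/hello-world", 42, "src/app.py", 7))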
@@ -909,8 +1085,9 @@ class GithubProvider(GitProvider):
        }}
        }}
        """
        response_tuple = self.github_client._Github__requester.requestJson(
            "POST", "/graphql", input={"query": query}
        )

        # Extract the JSON response from the tuple and parses it
        if isinstance(response_tuple, tuple) and len(response_tuple) == 3:
@@ -919,8 +1096,12 @@ class GithubProvider(GitProvider):
            get_logger().error(f"Unexpected response format: {response_tuple}")
            return sub_issues

        issue_id = (
            response_json.get("data", {})
            .get("repository", {})
            .get("issue", {})
            .get("id")
        )

        if not issue_id:
            get_logger().warning(f"Issue ID not found for {issue_url}")
@@ -940,22 +1121,42 @@ class GithubProvider(GitProvider):
        }}
        }}
        """
        sub_issues_response_tuple = (
            self.github_client._Github__requester.requestJson(
                "POST", "/graphql", input={"query": sub_issues_query}
            )
        )

        # Extract the JSON response from the tuple and parses it
        if (
            isinstance(sub_issues_response_tuple, tuple)
            and len(sub_issues_response_tuple) == 3
        ):
            sub_issues_response_json = json.loads(sub_issues_response_tuple[2])
        else:
            get_logger().error(
                "Unexpected sub-issues response format",
                artifact={"response": sub_issues_response_tuple},
            )
            return sub_issues

        if (
            not sub_issues_response_json.get("data", {})
            .get("node", {})
            .get("subIssues")
        ):
            get_logger().error("Invalid sub-issues response structure")
            return sub_issues

        nodes = (
            sub_issues_response_json.get("data", {})
            .get("node", {})
            .get("subIssues", {})
            .get("nodes", [])
        )
        get_logger().info(
            f"Github Sub-issues fetched: {len(nodes)}", artifact={"nodes": nodes}
        )
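Only the closing braces of the two GraphQL queries survive in these hunks (the provider builds them with f-strings, hence the doubled `}}`). As a rough sketch, inferred from how the responses are unpacked above (so treat the exact field shapes as assumptions, not the file's literal strings), the posted queries look approximately like:

# Approximate query shapes; OWNER, REPO, 123 and the node id are placeholders.
query = """
query {
  repository(owner: "OWNER", name: "REPO") {
    issue(number: 123) {
      id
    }
  }
}
"""

sub_issues_query = """
query {
  node(id: "ISSUE_NODE_ID") {
    ... on Issue {
      subIssues(first: 100) {
        nodes {
          url
        }
      }
    }
  }
}
"""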
        for sub_issue in nodes:
            if "url" in sub_issue:
@@ -986,7 +1187,8 @@ class GithubProvider(GitProvider):
        code_suggestions_copy = copy.deepcopy(code_suggestions)
        diff_files = self.get_diff_files()
        RE_HUNK_HEADER = re.compile(
            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)"
        )

        diff_files = set_file_languages(diff_files)
@@ -995,7 +1197,6 @@ class GithubProvider(GitProvider):
                relevant_file_path = suggestion['relevant_file']
                for file in diff_files:
                    if file.filename == relevant_file_path:
                        # generate on-demand the patches range for the relevant file
                        patch_str = file.patch
                        if not hasattr(file, 'patches_range'):
@@ -1006,14 +1207,30 @@ class GithubProvider(GitProvider):
                                match = RE_HUNK_HEADER.match(line)
                                # identify hunk header
                                if match:
                                    (
                                        section_header,
                                        size1,
                                        size2,
                                        start1,
                                        start2,
                                    ) = extract_hunk_headers(match)
                                    file.patches_range.append(
                                        {'start': start2, 'end': start2 + size2 - 1}
                                    )

                        patches_range = file.patches_range
                        comment_start_line = suggestion.get(
                            'relevant_lines_start', None
                        )
                        comment_end_line = suggestion.get('relevant_lines_end', None)
                        original_suggestion = suggestion.get(
                            'original_suggestion', None
                        )  # needed for diff code
                        if (
                            not comment_start_line
                            or not comment_end_line
                            or not original_suggestion
                        ):
                            continue

                        # check if the comment is inside a valid hunk
@@ -1037,30 +1254,57 @@ class GithubProvider(GitProvider):
                                patch_range_min = patch_range
                                min_distance = min(min_distance, d)
                        if not is_valid_hunk:
                            if (
                                min_distance < 10
                            ):  # 10 lines - a reasonable distance to consider the comment inside the hunk
                                # make the suggestion non-committable, yet multi line
                                suggestion['relevant_lines_start'] = max(
                                    suggestion['relevant_lines_start'],
                                    patch_range_min['start'],
                                )
                                suggestion['relevant_lines_end'] = min(
                                    suggestion['relevant_lines_end'],
                                    patch_range_min['end'],
                                )
                                body = suggestion['body'].strip()

                                # present new diff code in collapsible
                                existing_code = (
                                    original_suggestion['existing_code'].rstrip() + "\n"
                                )
                                improved_code = (
                                    original_suggestion['improved_code'].rstrip() + "\n"
                                )
                                diff = difflib.unified_diff(
                                    existing_code.split('\n'),
                                    improved_code.split('\n'),
                                    n=999,
                                )
                                patch_orig = "\n".join(diff)
                                patch = "\n".join(patch_orig.splitlines()[5:]).strip(
                                    '\n'
                                )
                                diff_code = f"\n\n<details><summary>新提议的代码:</summary>\n\n```diff\n{patch.rstrip()}\n```"
                                # replace ```suggestion ... ``` with diff_code, using regex:
                                body = re.sub(
                                    r'```suggestion.*?```',
                                    diff_code,
                                    body,
                                    flags=re.DOTALL,
                                )
                                body += "\n\n</details>"
                                suggestion['body'] = body
                                get_logger().info(
                                    f"Comment was moved to a valid hunk, "
                                    f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}"
                                )
                            else:
                                get_logger().error(
                                    f"Comment is not inside a valid hunk, "
                                    f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}"
                                )
            except Exception as e:
                get_logger().error(
                    f"Failed to process patch for committable comment, error: {e}"
                )

        return code_suggestions_copy
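The hunk-range bookkeeping above reduces to a small, self-contained computation. A standalone sketch, with the repo's `extract_hunk_headers` helper approximated inline from the regex groups (the real helper lives in the algo utils):

import re

# Same hunk-header pattern as in the provider code above.
RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

def patches_range(patch_str: str) -> list[dict]:
    """New-file line range covered by each hunk of a unified diff."""
    ranges = []
    for line in patch_str.splitlines():
        match = RE_HUNK_HEADER.match(line)
        if match:
            start2 = int(match.group(3))        # first new-file line of the hunk
            size2 = int(match.group(4) or 1)    # hunk length in the new file
            ranges.append({'start': start2, 'end': start2 + size2 - 1})
    return ranges

print(patches_range("@@ -10,4 +12,6 @@ def foo():\n+    added\n"))
# -> [{'start': 12, 'end': 17}]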

View File

@@ -10,9 +10,11 @@ from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from ..algo.file_filter import filter_ignored
from ..algo.language_handler import is_valid_file
from ..algo.utils import (
    clip_tokens,
    find_line_number_of_relevant_line_in_file,
    load_large_diff,
)
from ..config_loader import get_settings
from ..log import get_logger
from .git_provider import MAX_FILES_ALLOWED_FULL, GitProvider
@@ -20,22 +22,26 @@ from .git_provider import MAX_FILES_ALLOWED_FULL, GitProvider
class DiffNotFoundError(Exception):
    """Raised when the diff for a merge request cannot be found."""
    pass


class GitLabProvider(GitProvider):
    def __init__(
        self,
        merge_request_url: Optional[str] = None,
        incremental: Optional[bool] = False,
    ):
        gitlab_url = get_settings().get("GITLAB.URL", None)
        if not gitlab_url:
            raise ValueError("GitLab URL is not set in the config file")
        self.gitlab_url = gitlab_url
        gitlab_access_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
        if not gitlab_access_token:
            raise ValueError(
                "GitLab personal access token is not set in the config file"
            )
        self.gl = gitlab.Gitlab(url=gitlab_url, oauth_token=gitlab_access_token)
        self.max_comment_chars = 65000
        self.id_project = None
        self.id_mr = None
@@ -46,12 +52,17 @@ class GitLabProvider(GitProvider):
        self.pr_url = merge_request_url
        self._set_merge_request(merge_request_url)
        self.RE_HUNK_HEADER = re.compile(
            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)"
        )
        self.incremental = incremental

    def is_supported(self, capability: str) -> bool:
        if capability in [
            'get_issue_comments',
            'create_inline_comment',
            'publish_inline_comments',
            'publish_file_comments',
        ]:  # gfm_markdown is supported in gitlab !
            return False
        return True
@@ -67,12 +78,17 @@ class GitLabProvider(GitProvider):
            self.last_diff = self.mr.diffs.list(get_all=True)[-1]
        except IndexError as e:
            get_logger().error(f"Could not get diff for merge request {self.id_mr}")
            raise DiffNotFoundError(
                f"Could not get diff for merge request {self.id_mr}"
            ) from e

    def get_pr_file_content(self, file_path: str, branch: str) -> str:
        try:
            return (
                self.gl.projects.get(self.id_project)
                .files.get(file_path, branch)
                .decode()
            )
        except GitlabGetError:
            # In case of file creation the method returns GitlabGetError (404 file not found).
            # In this case we return an empty string for the diff.
@@ -98,10 +114,13 @@ class GitLabProvider(GitProvider):
        try:
            names_original = [diff['new_path'] for diff in diffs_original]
            names_filtered = [diff['new_path'] for diff in diffs]
            get_logger().info(
                f"Filtered out [ignore] files for merge request {self.id_mr}",
                extra={
                    'original_files': names_original,
                    'filtered_files': names_filtered,
                },
            )
        except Exception as e:
            pass
@@ -116,22 +135,31 @@ class GitLabProvider(GitProvider):
                # allow only a limited number of files to be fully loaded. We can manage the rest with diffs only
                counter_valid += 1
                if counter_valid < MAX_FILES_ALLOWED_FULL or not diff['diff']:
                    original_file_content_str = self.get_pr_file_content(
                        diff['old_path'], self.mr.diff_refs['base_sha']
                    )
                    new_file_content_str = self.get_pr_file_content(
                        diff['new_path'], self.mr.diff_refs['head_sha']
                    )
                else:
                    if counter_valid == MAX_FILES_ALLOWED_FULL:
                        get_logger().info(
                            f"Too many files in PR, will avoid loading full content for rest of files"
                        )
                    original_file_content_str = ''
                    new_file_content_str = ''

            try:
                if isinstance(original_file_content_str, bytes):
                    original_file_content_str = bytes.decode(
                        original_file_content_str, 'utf-8'
                    )
                if isinstance(new_file_content_str, bytes):
                    new_file_content_str = bytes.decode(new_file_content_str, 'utf-8')
            except UnicodeDecodeError:
                get_logger().warning(
                    f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}"
                )

            edit_type = EDIT_TYPE.MODIFIED
            if diff['new_file']:
@@ -144,30 +172,43 @@ class GitLabProvider(GitProvider):
            filename = diff['new_path']
            patch = diff['diff']
            if not patch:
                patch = load_large_diff(
                    filename, new_file_content_str, original_file_content_str
                )

            # count number of lines added and removed
            patch_lines = patch.splitlines(keepends=True)
            num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
            num_minus_lines = len(
                [line for line in patch_lines if line.startswith('-')]
            )
            diff_files.append(
                FilePatchInfo(
                    original_file_content_str,
                    new_file_content_str,
                    patch=patch,
                    filename=filename,
                    edit_type=edit_type,
                    old_filename=None
                    if diff['old_path'] == diff['new_path']
                    else diff['old_path'],
                    num_plus_lines=num_plus_lines,
                    num_minus_lines=num_minus_lines,
                )
            )

        if invalid_files_names:
            get_logger().info(
                f"Filtered out files with invalid extensions: {invalid_files_names}"
            )

        self.diff_files = diff_files
        return diff_files

    def get_files(self) -> list:
        if not self.git_files:
            self.git_files = [
                change['new_path'] for change in self.mr.changes()['changes']
            ]
        return self.git_files
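A worked example of the `num_plus_lines` / `num_minus_lines` counting just above. Since the check looks only at the first character, the `+++` and `---` file-header lines of this standalone patch string are counted as well:

patch = (
    "--- a/foo.py\n"
    "+++ b/foo.py\n"
    "@@ -1,2 +1,3 @@\n"
    " unchanged\n"
    "-removed\n"
    "+added one\n"
    "+added two\n"
)
patch_lines = patch.splitlines(keepends=True)
num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
num_minus_lines = len([line for line in patch_lines if line.startswith('-')])
print(num_plus_lines, num_minus_lines)  # 3 2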
    def publish_description(self, pr_title: str, pr_body: str):
@@ -176,7 +217,9 @@ class GitLabProvider(GitProvider):
            self.mr.description = pr_body
            self.mr.save()
        except Exception as e:
            get_logger().exception(
                f"Could not update merge request {self.id_mr} description: {e}"
            )

    def get_latest_commit_url(self):
        return self.mr.commits().next().web_url
@@ -184,16 +227,23 @@ class GitLabProvider(GitProvider):
    def get_comment_url(self, comment):
        return f"{self.mr.web_url}#note_{comment.id}"

    def publish_persistent_comment(
        self,
        pr_comment: str,
        initial_header: str,
        update_header: bool = True,
        name='review',
        final_update_message=True,
    ):
        self.publish_persistent_comment_full(
            pr_comment, initial_header, update_header, name, final_update_message
        )

    def publish_comment(self, mr_comment: str, is_temporary: bool = False):
        if is_temporary and not get_settings().config.publish_output_progress:
            get_logger().debug(
                f"Skipping publish_comment for temporary comment: {mr_comment}"
            )
            return None
        mr_comment = self.limit_output_characters(mr_comment, self.max_comment_chars)
        comment = self.mr.notes.create({'body': mr_comment})
@@ -216,39 +266,87 @@ class GitLabProvider(GitProvider):
        discussion = self.mr.discussions.get(comment_id)
        discussion.notes.create({'body': body})

    def publish_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        original_suggestion=None,
    ):
        body = self.limit_output_characters(body, self.max_comment_chars)
        (
            edit_type,
            found,
            source_line_no,
            target_file,
            target_line_no,
        ) = self.search_line(relevant_file, relevant_line_in_file)
        self.send_inline_comment(
            body,
            edit_type,
            found,
            relevant_file,
            relevant_line_in_file,
            source_line_no,
            target_file,
            target_line_no,
            original_suggestion,
        )

    def create_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        absolute_position: int = None,
    ):
        raise NotImplementedError(
            "Gitlab provider does not support creating inline comments yet"
        )

    def create_inline_comments(self, comments: list[dict]):
        raise NotImplementedError(
            "Gitlab provider does not support publishing inline comments yet"
        )

    def get_comment_body_from_comment_id(self, comment_id: int):
        comment = self.mr.notes.get(comment_id).body
        return comment

    def send_inline_comment(
        self,
        body: str,
        edit_type: str,
        found: bool,
        relevant_file: str,
        relevant_line_in_file: str,
        source_line_no: int,
        target_file: str,
        target_line_no: int,
        original_suggestion=None,
    ) -> None:
        if not found:
            get_logger().info(
                f"Could not find position for {relevant_file} {relevant_line_in_file}"
            )
        else:
            # in order to have exact sha's we have to find correct diff for this change
            diff = self.get_relevant_diff(relevant_file, relevant_line_in_file)
            if diff is None:
                get_logger().error(f"Could not get diff for merge request {self.id_mr}")
                raise DiffNotFoundError(
                    f"Could not get diff for merge request {self.id_mr}"
                )
            pos_obj = {
                'position_type': 'text',
                'new_path': target_file.filename,
                'old_path': target_file.old_filename
                if target_file.old_filename
                else target_file.filename,
                'base_sha': diff.base_commit_sha,
                'start_sha': diff.start_commit_sha,
                'head_sha': diff.head_commit_sha,
            }
            if edit_type == 'deletion':
                pos_obj['old_line'] = source_line_no - 1
            elif edit_type == 'addition':
@@ -256,15 +354,21 @@ class GitLabProvider(GitProvider):
            else:
                pos_obj['new_line'] = target_line_no - 1
                pos_obj['old_line'] = source_line_no - 1
            get_logger().debug(
                f"Creating comment in MR {self.id_mr} with body {body} and position {pos_obj}"
            )
            try:
                self.mr.discussions.create({'body': body, 'position': pos_obj})
            except Exception as e:
                try:
                    # fallback - create a general note on the file in the MR
                    if 'suggestion_orig_location' in original_suggestion:
                        line_start = original_suggestion['suggestion_orig_location'][
                            'start_line'
                        ]
                        line_end = original_suggestion['suggestion_orig_location'][
                            'end_line'
                        ]
                        old_code_snippet = original_suggestion['prev_code_snippet']
                        new_code_snippet = original_suggestion['new_code_snippet']
                        content = original_suggestion['suggestion_summary']
@@ -287,19 +391,25 @@ class GitLabProvider(GitProvider):
                    else:
                        language = ''
                    link = self.get_line_link(relevant_file, line_start, line_end)
                    body_fallback = (
                        f"**Suggestion:** {content} [{label}, importance: {score}]\n\n"
                    )
                    body_fallback += f"\n\n<details><summary>[{target_file.filename} [{line_start}-{line_end}]]({link}):</summary>\n\n"
                    body_fallback += f"\n\n___\n\n`(Cannot implement directly - GitLab API allows committable suggestions strictly on MR diff lines)`"
                    body_fallback += "</details>\n\n"
                    diff_patch = difflib.unified_diff(
                        old_code_snippet.split('\n'),
                        new_code_snippet.split('\n'),
                        n=999,
                    )
                    patch_orig = "\n".join(diff_patch)
                    patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
                    diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
                    body_fallback += diff_code

                    # Create a general note on the file in the MR
                    self.mr.notes.create(
                        {
                            'body': body_fallback,
                            'position': {
                                'base_sha': diff.base_commit_sha,
@@ -307,16 +417,23 @@ class GitLabProvider(GitProvider):
                                'head_sha': diff.head_commit_sha,
                                'position_type': 'text',
                                'file_path': f'{target_file.filename}',
                            },
                        }
                    )
                    get_logger().debug(
                        f"Created fallback comment in MR {self.id_mr} with position {pos_obj}"
                    )

                    # get_logger().debug(
                    #     f"Failed to create comment in MR {self.id_mr} with position {pos_obj} (probably not a '+' line)")
                except Exception as e:
                    get_logger().exception(
                        f"Failed to create comment in MR {self.id_mr}"
                    )

    def get_relevant_diff(
        self, relevant_file: str, relevant_line_in_file: str
    ) -> Optional[dict]:
        changes = self.mr.changes()  # Retrieve the changes for the merge request once
        if not changes:
            get_logger().error('No changes found for the merge request.')
@@ -327,10 +444,14 @@ class GitLabProvider(GitProvider):
            return None
        for diff in all_diffs:
            for change in changes['changes']:
                if (
                    change['new_path'] == relevant_file
                    and relevant_line_in_file in change['diff']
                ):
                    return diff
        get_logger().debug(
            f'No relevant diff found for {relevant_file} {relevant_line_in_file}. Falling back to last diff.'
        )
        return self.last_diff  # fallback to last_diff if no relevant diff is found

    def publish_code_suggestions(self, code_suggestions: list) -> bool:
@@ -365,10 +486,21 @@ class GitLabProvider(GitProvider):
                    found = True
                    edit_type = 'addition'
                self.send_inline_comment(
                    body,
                    edit_type,
                    found,
                    relevant_file,
                    relevant_line_in_file,
                    source_line_no,
                    target_file,
                    target_line_no,
                    original_suggestion,
                )
            except Exception as e:
                get_logger().exception(
                    f"Could not publish code suggestion:\nsuggestion: {suggestion}\nerror: {e}"
                )

        # note that we publish suggestions one-by-one. so, if one fails, the rest will still be published
        return True
@@ -382,8 +514,13 @@ class GitLabProvider(GitProvider):
        edit_type = self.get_edit_type(relevant_line_in_file)
        for file in self.get_diff_files():
            if file.filename == relevant_file:
                (
                    edit_type,
                    found,
                    source_line_no,
                    target_file,
                    target_line_no,
                ) = self.find_in_file(file, relevant_line_in_file)
        return edit_type, found, source_line_no, target_file, target_line_no

    def find_in_file(self, file, relevant_line_in_file):
@@ -414,7 +551,10 @@ class GitLabProvider(GitProvider):
                found = True
                edit_type = self.get_edit_type(line)
                break
            elif (
                relevant_line_in_file[0] == '+'
                and relevant_line_in_file[1:].lstrip() in line
            ):
                # The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
                # it's a context line
                found = True
@@ -470,7 +610,11 @@ class GitLabProvider(GitProvider):
    def get_repo_settings(self):
        try:
            contents = (
                self.gl.projects.get(self.id_project)
                .files.get(file_path='.pr_agent.toml', ref=self.mr.target_branch)
                .decode()
            )
            return contents
        except Exception:
            return ""
@@ -478,7 +622,9 @@ class GitLabProvider(GitProvider):
    def get_workspace_name(self):
        return self.id_project.split('/')[0]

    def add_eyes_reaction(
        self, issue_comment_id: int, disable_eyes: bool = False
    ) -> Optional[int]:
        return True

    def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
@@ -489,7 +635,9 @@ class GitLabProvider(GitProvider):
        path_parts = parsed_url.path.strip('/').split('/')

        if 'merge_requests' not in path_parts:
            raise ValueError(
                "The provided URL does not appear to be a GitLab merge request URL"
            )

        mr_index = path_parts.index('merge_requests')
        # Ensure there is an ID after 'merge_requests'
@@ -541,8 +689,15 @@ class GitLabProvider(GitProvider):
        """
        max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None)
        try:
            commit_messages_list = [
                commit['message'] for commit in self.mr.commits()._list
            ]
            commit_messages_str = "\n".join(
                [
                    f"{i + 1}. {message}"
                    for i, message in enumerate(commit_messages_list)
                ]
            )
        except Exception:
            commit_messages_str = ""
        if max_tokens:
@@ -556,7 +711,12 @@ class GitLabProvider(GitProvider):
        except:
            return ""

    def get_line_link(
        self,
        relevant_file: str,
        relevant_line_start: int,
        relevant_line_end: int = None,
    ) -> str:
        if relevant_line_start == -1:
            link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads"
        elif relevant_line_end:
@@ -565,7 +725,6 @@ class GitLabProvider(GitProvider):
            link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads#L{relevant_line_start}"
        return link

    def generate_link_to_relevant_line_number(self, suggestion) -> str:
        try:
            relevant_file = suggestion['relevant_file'].strip('`').strip("'").rstrip()
@@ -573,8 +732,9 @@ class GitLabProvider(GitProvider):
            if not relevant_line_str:
                return ""

            position, absolute_position = find_line_number_of_relevant_line_in_file(
                self.diff_files, relevant_file, relevant_line_str
            )
            if absolute_position != -1:
                # link to right file only
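A minimal usage sketch for the provider above, assuming `GITLAB.URL` and `GITLAB.PERSONAL_ACCESS_TOKEN` are already set in the settings; the module path and MR URL are assumptions for illustration:

from utils.pr_agent.git_providers.gitlab_provider import GitLabProvider

# Hypothetical MR URL; the provider parses the project path and MR id from it.
provider = GitLabProvider(
    "https://gitlab.example.com/group/project/-/merge_requests/42"
)
provider.publish_comment("Automated review starting...")
for file_info in provider.get_diff_files():
    print(file_info.filename, file_info.edit_type)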

View File

@@ -39,10 +39,16 @@ class LocalGitProvider(GitProvider):
        self._prepare_repo()
        self.diff_files = None
        self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())
        self.description_path = (
            get_settings().get('local.description_path')
            if get_settings().get('local.description_path') is not None
            else self.repo_path / 'description.md'
        )
        self.review_path = (
            get_settings().get('local.review_path')
            if get_settings().get('local.review_path') is not None
            else self.repo_path / 'review.md'
        )
        # inline code comments are not supported for local git repositories
        get_settings().pr_reviewer.inline_code_comments = False
@@ -52,30 +58,43 @@ class LocalGitProvider(GitProvider):
        """
        get_logger().debug('Preparing repository for PR-mimic generation...')
        if self.repo.is_dirty():
            raise ValueError(
                'The repository is not in a clean state. Please commit or stash pending changes.'
            )
        if self.target_branch_name not in self.repo.heads:
            raise KeyError(f'Branch: {self.target_branch_name} does not exist')

    def is_supported(self, capability: str) -> bool:
        if capability in [
            'get_issue_comments',
            'create_inline_comment',
            'publish_inline_comments',
            'get_labels',
            'gfm_markdown',
        ]:
            return False
        return True

    def get_diff_files(self) -> list[FilePatchInfo]:
        diffs = self.repo.head.commit.diff(
            self.repo.merge_base(
                self.repo.head, self.repo.branches[self.target_branch_name]
            ),
            create_patch=True,
            R=True,
        )
        diff_files = []
        for diff_item in diffs:
            if diff_item.a_blob is not None:
                original_file_content_str = diff_item.a_blob.data_stream.read().decode(
                    'utf-8'
                )
            else:
                original_file_content_str = ""  # empty file
            if diff_item.b_blob is not None:
                new_file_content_str = diff_item.b_blob.data_stream.read().decode(
                    'utf-8'
                )
            else:
                new_file_content_str = ""  # empty file
            edit_type = EDIT_TYPE.MODIFIED
@@ -86,12 +105,15 @@ class LocalGitProvider(GitProvider):
            elif diff_item.renamed_file:
                edit_type = EDIT_TYPE.RENAMED
            diff_files.append(
                FilePatchInfo(
                    original_file_content_str,
                    new_file_content_str,
                    diff_item.diff.decode('utf-8'),
                    diff_item.b_path,
                    edit_type=edit_type,
                    old_filename=None
                    if diff_item.a_path == diff_item.b_path
                    else diff_item.a_path,
                )
            )
        self.diff_files = diff_files
@@ -102,8 +124,10 @@ class LocalGitProvider(GitProvider):
        Returns a list of files with changes in the diff.
        """
        diff_index = self.repo.head.commit.diff(
            self.repo.merge_base(
                self.repo.head, self.repo.branches[self.target_branch_name]
            ),
            R=True,
        )
        # Get the list of changed files
        diff_files = [item.a_path for item in diff_index]
@@ -119,18 +143,37 @@ class LocalGitProvider(GitProvider):
            # Write the string to the file
            file.write(pr_comment)

    def publish_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        original_suggestion=None,
    ):
        raise NotImplementedError(
            'Publishing inline comments is not implemented for the local git provider'
        )

    def publish_inline_comments(self, comments: list[dict]):
        raise NotImplementedError(
            'Publishing inline comments is not implemented for the local git provider'
        )

    def publish_code_suggestion(
        self,
        body: str,
        relevant_file: str,
        relevant_lines_start: int,
        relevant_lines_end: int,
    ):
        raise NotImplementedError(
            'Publishing code suggestions is not implemented for the local git provider'
        )

    def publish_code_suggestions(self, code_suggestions: list) -> bool:
        raise NotImplementedError(
            'Publishing code suggestions is not implemented for the local git provider'
        )

    def publish_labels(self, labels):
        pass  # Not applicable to the local git provider, but required by the interface
@@ -158,19 +201,31 @@ class LocalGitProvider(GitProvider):
        Calculate percentage of languages in repository. Used for hunk prioritisation.
        """
        # Get all files in repository
        filepaths = [
            Path(item.path)
            for item in self.repo.tree().traverse()
            if item.type == 'blob'
        ]
        # Identify language by file extension and count
        lang_count = Counter(
            ext.lstrip('.')
            for filepath in filepaths
            for ext in [filepath.suffix.lower()]
        )
        # Convert counts to percentages
        total_files = len(filepaths)
        lang_percentage = {
            lang: count / total_files * 100 for lang, count in lang_count.items()
        }
        return lang_percentage

    def get_pr_branch(self):
        return self.repo.head

    def get_user_id(self):
        return (
            -1
        )  # Not used anywhere for the local provider, but required by the interface

    def get_pr_description_full(self):
        commits_diff = list(self.repo.iter_commits(self.target_branch_name + '..HEAD'))
@@ -186,7 +241,11 @@ class LocalGitProvider(GitProvider):
        return self.head_branch_name

    def get_issue_comments(self):
        raise NotImplementedError(
            'Getting issue comments is not implemented for the local git provider'
        )

    def get_pr_labels(self, update=False):
        raise NotImplementedError(
            'Getting labels is not implemented for the local git provider'
        )
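The language statistics above reduce to a small, self-contained computation over file extensions. A sketch with a hypothetical file list:

from collections import Counter
from pathlib import Path

# Hypothetical repository listing.
filepaths = [Path("a.py"), Path("b.py"), Path("c.js"), Path("README.md")]

lang_count = Counter(
    ext.lstrip('.') for filepath in filepaths for ext in [filepath.suffix.lower()]
)
total_files = len(filepaths)
lang_percentage = {
    lang: count / total_files * 100 for lang, count in lang_count.items()
}
print(lang_percentage)  # {'py': 50.0, 'js': 25.0, 'md': 25.0}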

View File

@@ -6,7 +6,7 @@ from dynaconf import Dynaconf
from starlette_context import context

from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import get_git_provider_with_context
from utils.pr_agent.log import get_logger
@@ -20,7 +20,9 @@ def apply_repo_settings(pr_url):
        except Exception:
            repo_settings = None
            pass

    if (
        repo_settings is None
    ):  # None is different from "", which is a valid value
        repo_settings = git_provider.get_repo_settings()
        try:
            context["repo_settings"] = repo_settings
@@ -36,15 +38,25 @@ def apply_repo_settings(pr_url):
                os.write(fd, repo_settings)
                new_settings = Dynaconf(settings_files=[repo_settings_file])
                for section, contents in new_settings.as_dict().items():
                    section_dict = copy.deepcopy(
                        get_settings().as_dict().get(section, {})
                    )
                    for key, value in contents.items():
                        section_dict[key] = value
                    get_settings().unset(section)
                    get_settings().set(section, section_dict, merge=False)
                get_logger().info(
                    f"Applying repo settings:\n{new_settings.as_dict()}"
                )
            except Exception as e:
                get_logger().warning(
                    f"Failed to apply repo {category} settings, error: {str(e)}"
                )
                error_local = {
                    'error': str(e),
                    'settings': repo_settings,
                    'category': category,
                }

            if error_local:
                handle_configurations_errors([error_local], git_provider)
@@ -55,7 +67,10 @@ def apply_repo_settings(pr_url):
                try:
                    os.remove(repo_settings_file)
                except Exception as e:
                    get_logger().error(
                        f"Failed to remove temporary settings file {repo_settings_file}",
                        e,
                    )

    # enable switching models with a short definition
    if get_settings().config.model.lower() == 'claude-3-5-sonnet':
@@ -79,13 +94,18 @@ def handle_configurations_errors(config_errors, git_provider):
                body += f"\n\n<details><summary>配置内容:</summary>\n\n```toml\n{configuration_file_content}\n```\n\n</details>"
            else:
                body += f"\n\n**配置内容:**\n\n```toml\n{configuration_file_content}\n```\n\n"

            get_logger().warning(
                f"Sending a 'configuration error' comment to the PR",
                artifact={'body': body},
            )
            # git_provider.publish_comment(body)
            if hasattr(git_provider, 'publish_persistent_comment'):
                git_provider.publish_persistent_comment(
                    body,
                    initial_header=header,
                    update_header=False,
                    final_update_message=False,
                )
            else:
                git_provider.publish_comment(body)
        except Exception as e:
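The Dynaconf load at the heart of `apply_repo_settings` can be exercised in isolation. A sketch assuming a small repo-level TOML file written to a temporary path, as the function does:

import os
import tempfile

from dynaconf import Dynaconf

# Write a tiny repo-level settings file, the way apply_repo_settings does.
fd, settings_file = tempfile.mkstemp(suffix='.toml')
os.write(fd, b"[pr_reviewer]\nextra_instructions = 'be strict'\n")
os.close(fd)

new_settings = Dynaconf(settings_files=[settings_file])
for section, contents in new_settings.as_dict().items():
    print(section, contents)  # e.g. PR_REVIEWER {'extra_instructions': 'be strict'}

os.remove(settings_file)

The loop in the code above then deep-copies the existing section from the global settings, overlays the repo-level keys on top, and re-sets the section with merge=False, so a repo-level key wins over the global default while untouched keys survive.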

View File

@@ -1,10 +1,9 @@
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.identity_providers.default_identity_provider import (
    DefaultIdentityProvider,
)

_IDENTITY_PROVIDERS = {'default': DefaultIdentityProvider}


def get_identity_provider():

View File

@@ -1,5 +1,7 @@
from utils.pr_agent.identity_providers.identity_provider import (
    Eligibility,
    IdentityProvider,
)


class DefaultIdentityProvider(IdentityProvider):

View File

@@ -30,7 +30,9 @@ def setup_logger(level: str = "INFO", fmt: LoggingFormat = LoggingFormat.CONSOLE
    if type(level) is not int:
        level = logging.INFO

    if (
        fmt == LoggingFormat.JSON and os.getenv("LOG_SANE", "0").lower() == "0"
    ):  # better debugging github_app
        logger.remove(None)
        logger.add(
            sys.stdout,
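The JSON branch above follows loguru's standard pattern for structured output. A minimal sketch with the sink options reduced to the essentials (the real call passes more arguments than the hunk shows):

import sys

from loguru import logger

logger.remove(None)  # drop the default stderr sink
logger.add(
    sys.stdout,
    level="DEBUG",
    serialize=True,  # emit each record as a JSON line
)
logger.info("structured hello")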

View File

@@ -8,10 +8,14 @@ def get_secret_provider():
    provider_id = get_settings().config.secret_provider
    if provider_id == 'google_cloud_storage':
        try:
            from utils.pr_agent.secret_providers.google_cloud_storage_secret_provider import (
                GoogleCloudStorageSecretProvider,
            )

            return GoogleCloudStorageSecretProvider()
        except Exception as e:
            raise ValueError(
                f"Failed to initialize google_cloud_storage secret provider {provider_id}"
            ) from e
    else:
        raise ValueError("Unknown SECRET_PROVIDER")

View File

@@ -9,12 +9,15 @@ from utils.pr_agent.secret_providers.secret_provider import SecretProvider
class GoogleCloudStorageSecretProvider(SecretProvider):
    def __init__(self):
        try:
            self.client = storage.Client.from_service_account_info(
                ujson.loads(get_settings().google_cloud_storage.service_account)
            )
            self.bucket_name = get_settings().google_cloud_storage.bucket_name
            self.bucket = self.client.bucket(self.bucket_name)
        except Exception as e:
            get_logger().error(
                f"Failed to initialize Google Cloud Storage Secret Provider: {e}"
            )
            raise e

    def get_secret(self, secret_name: str) -> str:
@@ -22,7 +25,9 @@ class GoogleCloudStorageSecretProvider(SecretProvider):
            blob = self.bucket.blob(secret_name)
            return blob.download_as_string()
        except Exception as e:
            get_logger().warning(
                f"Failed to get secret {secret_name} from Google Cloud Storage: {e}"
            )
            return ""

    def store_secret(self, secret_name: str, secret_value: str):
@@ -30,5 +35,7 @@ class GoogleCloudStorageSecretProvider(SecretProvider):
            blob = self.bucket.blob(secret_name)
            blob.upload_from_string(secret_value)
        except Exception as e:
            get_logger().error(
                f"Failed to store secret {secret_name} in Google Cloud Storage: {e}"
            )
            raise e
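A usage sketch for the provider above, assuming `CONFIG.SECRET_PROVIDER = "google_cloud_storage"` and that `google_cloud_storage.service_account` (a service-account JSON string) and `bucket_name` are configured; the secret name and value are hypothetical:

from utils.pr_agent.secret_providers import get_secret_provider

provider = get_secret_provider()
provider.store_secret("installation-123", "shared-secret-value")
print(provider.get_secret("installation-123"))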

View File

@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
class SecretProvider(ABC):
    @abstractmethod
    def get_secret(self, secret_name: str) -> str:
        pass

View File

@@ -33,6 +33,7 @@ azure_devops_server = get_settings().get("azure_devops_server")
WEBHOOK_USERNAME = azure_devops_server.get("webhook_username")
WEBHOOK_PASSWORD = azure_devops_server.get("webhook_password")


def handle_request(
    background_tasks: BackgroundTasks, url: str, body: str, log_context: dict
):
@@ -62,10 +63,17 @@ def authorize(credentials: HTTPBasicCredentials = Depends(security)):
        )


async def _perform_commands_azure(
    commands_conf: str, agent: PRAgent, api_url: str, log_context: dict
):
    apply_repo_settings(api_url)
    if (
        commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
    ):  # auto commands for PR, and auto feedback is disabled
        get_logger().info(
            f"Auto feedback is disabled, skipping auto commands for PR {api_url=}",
            **log_context,
        )
        return
    commands = get_settings().get(f"azure_devops_server.{commands_conf}")
    get_settings().set("config.is_auto_command", True)
@@ -92,22 +100,38 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
    actions = []
    if data["eventType"] == "git.pullrequest.created":
        # API V1 (latest)
        pr_url = unquote(
            data["resource"]["_links"]["web"]["href"].replace(
                "_apis/git/repositories", "_git"
            )
        )
        log_context["event"] = data["eventType"]
        log_context["api_url"] = pr_url
        await _perform_commands_azure("pr_commands", PRAgent(), pr_url, log_context)
        return
    elif (
        data["eventType"] == "ms.vss-code.git-pullrequest-comment-event"
        and "content" in data["resource"]["comment"]
    ):
        if available_commands_rgx.match(data["resource"]["comment"]["content"]):
            if data["resourceVersion"] == "2.0":
                repo = data["resource"]["pullRequest"]["repository"]["webUrl"]
                pr_url = unquote(
                    f'{repo}/pullrequest/{data["resource"]["pullRequest"]["pullRequestId"]}'
                )
                actions = [data["resource"]["comment"]["content"]]
            else:
                # API V1 not supported as it does not contain the PR URL
                return (
                    JSONResponse(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        content=json.dumps(
                            {
                                "message": "version 1.0 webhook for Azure Devops PR comment is not supported. please upgrade to version 2.0"
                            }
                        ),
                    ),
                )
        else:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
@@ -132,17 +156,21 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
            content=json.dumps({"message": "Internal server error"}),
        )
    return JSONResponse(
        status_code=status.HTTP_202_ACCEPTED,
        content=jsonable_encoder({"message": "webhook triggered successfully"}),
    )


@router.get("/")
async def root():
    return {"status": "ok"}


def start():
    app = FastAPI(middleware=[Middleware(RawContextMiddleware)])
    app.include_router(router)
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "3000")))


if __name__ == "__main__":
    start()
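Only the closing parenthesis of the `authorize` dependency survives in the hunks above. It follows FastAPI's standard HTTP Basic pattern; a minimal sketch under that assumption, with constant-time credential comparison as in FastAPI's docs (the placeholder credentials come from settings in the real server):

import secrets

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials

WEBHOOK_USERNAME = "user"      # placeholder; read from settings in the real server
WEBHOOK_PASSWORD = "password"  # placeholder

security = HTTPBasic()

def authorize(credentials: HTTPBasicCredentials = Depends(security)):
    is_user_ok = secrets.compare_digest(credentials.username, WEBHOOK_USERNAME)
    is_pass_ok = secrets.compare_digest(credentials.password, WEBHOOK_PASSWORD)
    if not (is_user_ok and is_pass_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password.",
            headers={"WWW-Authenticate": "Basic"},
        )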

View File

@ -27,7 +27,9 @@ from utils.pr_agent.secret_providers import get_secret_provider
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG") setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
router = APIRouter() router = APIRouter()
secret_provider = get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None secret_provider = (
get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
)
async def get_bearer_token(shared_secret: str, client_key: str): async def get_bearer_token(shared_secret: str, client_key: str):
@@ -49,7 +51,7 @@ async def get_bearer_token(shared_secret: str, client_key: str):

        payload = 'grant_type=urn%3Abitbucket%3Aoauth2%3Ajwt'
        headers = {
            'Authorization': f'JWT {token}',
            'Content-Type': 'application/x-www-form-urlencoded',
        }
        response = requests.request("POST", url, headers=headers, data=payload)
        bearer_token = response.json()["access_token"]
@@ -58,6 +60,7 @@ async def get_bearer_token(shared_secret: str, client_key: str):

        get_logger().error(f"Failed to get bearer token: {e}")
        raise e


@router.get("/")
async def handle_manifest(request: Request, response: Response):
    cur_dir = os.path.dirname(os.path.abspath(__file__))
@@ -66,7 +69,9 @@ async def handle_manifest(request: Request, response: Response):

        manifest = manifest.replace("app_key", get_settings().bitbucket.app_key)
        manifest = manifest.replace("base_url", get_settings().bitbucket.base_url)
    except:
        get_logger().error(
            "Failed to replace api_key in Bitbucket manifest, trying to continue"
        )
    manifest_obj = json.loads(manifest)
    return JSONResponse(manifest_obj)
@@ -83,10 +88,16 @@ def _get_username(data):

    return ""


async def _perform_commands_bitbucket(
    commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict
):
    apply_repo_settings(api_url)
    if (
        commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
    ):  # auto commands for PR, and auto feedback is disabled
        get_logger().info(
            f"Auto feedback is disabled, skipping auto commands for PR {api_url=}"
        )
        return
    if data.get("event", "") == "pullrequest:created":
        if not should_process_pr_logic(data):
@@ -132,7 +143,9 @@ def should_process_pr_logic(data) -> bool:

        ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
        if ignore_pr_users and sender:
            if sender in ignore_pr_users:
                get_logger().info(
                    f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting"
                )
                return False

        # logic to ignore PRs with specific titles
@@ -140,20 +153,34 @@ def should_process_pr_logic(data) -> bool:

        ignore_pr_title_re = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
        if not isinstance(ignore_pr_title_re, list):
            ignore_pr_title_re = [ignore_pr_title_re]
        if ignore_pr_title_re and any(
            re.search(regex, title) for regex in ignore_pr_title_re
        ):
            get_logger().info(
                f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting"
            )
            return False

        ignore_pr_source_branches = get_settings().get(
            "CONFIG.IGNORE_PR_SOURCE_BRANCHES", []
        )
        ignore_pr_target_branches = get_settings().get(
            "CONFIG.IGNORE_PR_TARGET_BRANCHES", []
        )
        if ignore_pr_source_branches or ignore_pr_target_branches:
            if any(
                re.search(regex, source_branch) for regex in ignore_pr_source_branches
            ):
                get_logger().info(
                    f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings"
                )
                return False
            if any(
                re.search(regex, target_branch) for regex in ignore_pr_target_branches
            ):
                get_logger().info(
                    f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings"
                )
                return False
    except Exception as e:
        get_logger().error(f"Failed 'should_process_pr_logic': {e}")
@@ -195,7 +222,9 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Request):

            client_key = claims["iss"]
            secrets = json.loads(secret_provider.get_secret(client_key))
            shared_secret = secrets["shared_secret"]
            jwt.decode(
                input_jwt, shared_secret, audience=client_key, algorithms=["HS256"]
            )
            bearer_token = await get_bearer_token(shared_secret, client_key)
            context['bitbucket_bearer_token'] = bearer_token
            context["settings"] = copy.deepcopy(global_settings)
@@ -208,28 +237,41 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Request):

                if pr_url:
                    with get_logger().contextualize(**log_context):
                        apply_repo_settings(pr_url)
                        if (
                            get_identity_provider().verify_eligibility(
                                "bitbucket", sender_id, pr_url
                            )
                            is not Eligibility.NOT_ELIGIBLE
                        ):
                            if get_settings().get("bitbucket_app.pr_commands"):
                                await _perform_commands_bitbucket(
                                    "pr_commands", PRAgent(), pr_url, log_context, data
                                )
            elif event == "pullrequest:comment_created":
                pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
                log_context["api_url"] = pr_url
                log_context["event"] = "comment"
                comment_body = data["data"]["comment"]["content"]["raw"]
                with get_logger().contextualize(**log_context):
                    if (
                        get_identity_provider().verify_eligibility(
                            "bitbucket", sender_id, pr_url
                        )
                        is not Eligibility.NOT_ELIGIBLE
                    ):
                        await agent.handle_request(pr_url, comment_body)
        except Exception as e:
            get_logger().error(f"Failed to handle webhook: {e}")

    background_tasks.add_task(inner)
    return "OK"


@router.get("/webhook")
async def handle_github_webhooks(request: Request, response: Response):
    return "Webhook server online!"


@router.post("/installed")
async def handle_installed_webhooks(request: Request, response: Response):
    try:
@@ -240,15 +282,13 @@ async def handle_installed_webhooks(request: Request, response: Response):

        shared_secret = data["sharedSecret"]
        client_key = data["clientKey"]
        username = data["principal"]["username"]
        secrets = {"shared_secret": shared_secret, "client_key": client_key}
        secret_provider.store_secret(username, json.dumps(secrets))
    except Exception as e:
        get_logger().error(f"Failed to register user: {e}")
        return JSONResponse({"error": "Unable to register user"}, status_code=500)


@router.post("/uninstalled")
async def handle_uninstalled_webhooks(request: Request, response: Response):
    get_logger().info("handle_uninstalled_webhooks")
@@ -40,10 +40,12 @@ def handle_request(

    background_tasks.add_task(inner)


@router.post("/")
async def redirect_to_webhook():
    return RedirectResponse(url="/webhook")


@router.post("/webhook")
async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
    log_context = {"server_type": "bitbucket_server"}
@@ -55,7 +57,8 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):

    body_bytes = await request.body()
    if body_bytes.decode('utf-8') == '{"test": true}':
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content=jsonable_encoder({"message": "connection test successful"}),
        )
    signature_header = request.headers.get("x-hub-signature", None)
    verify_signature(body_bytes, webhook_secret, signature_header)
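
verify_signature is pr-agent's own helper; a minimal sketch of the standard x-hub-signature HMAC check it presumably performs (the exact digest algorithm and error handling may differ):

import hashlib
import hmac

def verify_signature_sketch(body: bytes, secret: str, signature_header: str) -> None:
    # GitHub-style header: "sha256=<hex digest of HMAC-SHA256(secret, body)>"
    expected = "sha256=" + hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()
    if not signature_header or not hmac.compare_digest(expected, signature_header):
        raise ValueError("Invalid webhook signature")
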
@@ -73,11 +76,18 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):

    if data["eventKey"] == "pr:opened":
        apply_repo_settings(pr_url)
        if (
            get_settings().config.disable_auto_feedback
        ):  # auto commands for PR, and auto feedback is disabled
            get_logger().info(
                f"Auto feedback is disabled, skipping auto commands for PR {pr_url}",
                **log_context,
            )
            return
        get_settings().set("config.is_auto_command", True)
        commands_to_run.extend(
            _get_commands_list_from_settings('BITBUCKET_SERVER.PR_COMMANDS')
        )
    elif data["eventKey"] == "pr:comment:added":
        commands_to_run.append(data["comment"]["text"])
    else:
@@ -116,6 +126,7 @@ async def _run_commands_sequentially(commands: List[str], url: str, log_context: dict):

        except Exception as e:
            get_logger().error(f"Failed to handle command: {command} , error: {e}")


def _process_command(command: str, url) -> str:
    # don't think we need this
    apply_repo_settings(url)
@@ -146,7 +157,9 @@ def _get_commands_list_from_settings(setting_key: str) -> list:

    try:
        return get_settings().get(setting_key, [])
    except ValueError as e:
        get_logger().error(
            f"Failed to get commands list from settings {setting_key}: {e}"
        )


@router.get("/")
@@ -40,12 +40,10 @@ async def handle_gerrit_request(action: Action, item: Item):

    if action == Action.ask:
        if not item.msg:
            raise HTTPException(
                status_code=400, detail="msg is required for ask command"
            )
    await PRAgent().handle_request(
        f"{item.project}:{item.refspec}", f"/{item.msg.strip()}"
    )
@@ -26,7 +26,12 @@ def get_setting_or_env(key: str, default: Union[str, bool] = None) -> Union[str, bool]:

    try:
        value = get_settings().get(key, default)
    except AttributeError:  # TBD still need to debug why this happens on GitHub Actions
        value = (
            os.getenv(key, None)
            or os.getenv(key.upper(), None)
            or os.getenv(key.lower(), None)
            or default
        )
    return value
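
A toy illustration (key and value hypothetical) of the fallback chain above:

import os

# Hypothetical: the key is only present in the environment, in upper-case form.
os.environ["GITHUB_ACTION.AUTO_REVIEW"] = "true"

# If get_settings().get("github_action.auto_review") raises AttributeError,
# the env fallbacks fire in order:
#   os.getenv("github_action.auto_review")  -> None
#   os.getenv("GITHUB_ACTION.AUTO_REVIEW")  -> "true"  (key.upper() matches)
#   key.lower() and the default are never reached.
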
@@ -76,16 +81,24 @@ async def run_action():

        pr_url = event_payload.get("pull_request", {}).get("html_url")
        if pr_url:
            apply_repo_settings(pr_url)
            get_logger().info(
                f"enable_custom_labels: {get_settings().config.enable_custom_labels}"
            )
    except Exception as e:
        get_logger().info(f"github action: failed to apply repo settings: {e}")

    # Handle pull request opened event
    if (
        GITHUB_EVENT_NAME == "pull_request"
        or GITHUB_EVENT_NAME == "pull_request_target"
    ):
        action = event_payload.get("action")

        # Retrieve the list of actions from the configuration
        pr_actions = get_settings().get(
            "GITHUB_ACTION_CONFIG.PR_ACTIONS",
            ["opened", "reopened", "ready_for_review", "review_requested"],
        )
        if action in pr_actions:
            pr_url = event_payload.get("pull_request", {}).get("url")
@@ -93,18 +106,30 @@ async def run_action():

            # legacy - supporting both GITHUB_ACTION and GITHUB_ACTION_CONFIG
            auto_review = get_setting_or_env("GITHUB_ACTION.AUTO_REVIEW", None)
            if auto_review is None:
                auto_review = get_setting_or_env(
                    "GITHUB_ACTION_CONFIG.AUTO_REVIEW", None
                )
            auto_describe = get_setting_or_env("GITHUB_ACTION.AUTO_DESCRIBE", None)
            if auto_describe is None:
                auto_describe = get_setting_or_env(
                    "GITHUB_ACTION_CONFIG.AUTO_DESCRIBE", None
                )
            auto_improve = get_setting_or_env("GITHUB_ACTION.AUTO_IMPROVE", None)
            if auto_improve is None:
                auto_improve = get_setting_or_env(
                    "GITHUB_ACTION_CONFIG.AUTO_IMPROVE", None
                )

            # Set the configuration for auto actions
            get_settings().config.is_auto_command = (
                True  # Set the flag to indicate that the command is auto
            )
            get_settings().pr_description.final_update_message = (
                False  # No final update message when auto_describe is enabled
            )
            get_logger().info(
                f"Running auto actions: auto_describe={auto_describe}, auto_review={auto_review}, auto_improve={auto_improve}"
            )

            # invoke by default all three tools
            if auto_describe is None or is_true(auto_describe):
@@ -117,7 +142,10 @@ async def run_action():

            get_logger().info(f"Skipping action: {action}")

    # Handle issue comment event
    elif (
        GITHUB_EVENT_NAME == "issue_comment"
        or GITHUB_EVENT_NAME == "pull_request_review_comment"
    ):
        action = event_payload.get("action")
        if action in ["created", "edited"]:
            comment_body = event_payload.get("comment", {}).get("body")
@@ -133,9 +161,15 @@ async def run_action():

            disable_eyes = False
            # check if issue is pull request
            if event_payload.get("issue", {}).get("pull_request"):
                url = (
                    event_payload.get("issue", {})
                    .get("pull_request", {})
                    .get("url")
                )
                is_pr = True
            elif event_payload.get("comment", {}).get(
                "pull_request_url"
            ):  # for 'pull_request_review_comment
                url = event_payload.get("comment", {}).get("pull_request_url")
                is_pr = True
                disable_eyes = True
@@ -148,9 +182,11 @@ async def run_action():

            provider = get_git_provider()(pr_url=url)
            if is_pr:
                await PRAgent().handle_request(
                    url,
                    body,
                    notify=lambda: provider.add_eyes_reaction(
                        comment_id, disable_eyes=disable_eyes
                    ),
                )
            else:
                await PRAgent().handle_request(url, body)
@@ -15,8 +15,7 @@ from starlette_context.middleware import RawContextMiddleware

from utils.pr_agent.agent.pr_agent import PRAgent
from utils.pr_agent.algo.utils import update_settings_from_args
from utils.pr_agent.config_loader import get_settings, global_settings
from utils.pr_agent.git_providers import get_git_provider, get_git_provider_with_context
from utils.pr_agent.git_providers.utils import apply_repo_settings
from utils.pr_agent.identity_providers import get_identity_provider
from utils.pr_agent.identity_providers.identity_provider import Eligibility
@@ -35,7 +34,9 @@ router = APIRouter()

@router.post("/api/v1/github_webhooks")
async def handle_github_webhooks(
    background_tasks: BackgroundTasks, request: Request, response: Response
):
    """
    Receives and processes incoming GitHub webhook requests.
    Verifies the request signature, parses the request body, and passes it to the handle_request function for further

@@ -49,7 +50,9 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Request):

    context["installation_id"] = installation_id
    context["settings"] = copy.deepcopy(global_settings)
    context["git_provider"] = {}
    background_tasks.add_task(
        handle_request, body, event=request.headers.get("X-GitHub-Event", None)
    )
    return {}
@@ -73,35 +76,61 @@ async def get_body(request):

    return body


_duplicate_push_triggers = DefaultDictWithTimeout(
    ttl=get_settings().github_app.push_trigger_pending_tasks_ttl
)
_pending_task_duplicate_push_conditions = DefaultDictWithTimeout(
    asyncio.locks.Condition,
    ttl=get_settings().github_app.push_trigger_pending_tasks_ttl,
)


async def handle_comments_on_pr(
    body: Dict[str, Any],
    event: str,
    sender: str,
    sender_id: str,
    action: str,
    log_context: Dict[str, Any],
    agent: PRAgent,
):
    if "comment" not in body:
        return {}
    comment_body = body.get("comment", {}).get("body")
    if (
        comment_body
        and isinstance(comment_body, str)
        and not comment_body.lstrip().startswith("/")
    ):
        if '/ask' in comment_body and comment_body.strip().startswith('> ![image]'):
            comment_body_split = comment_body.split('/ask')
            comment_body = (
                '/ask'
                + comment_body_split[1]
                + ' \n'
                + comment_body_split[0].strip().lstrip('>')
            )
            get_logger().info(
                f"Reformatting comment_body so command is at the beginning: {comment_body}"
            )
        else:
            get_logger().info("Ignoring comment not starting with /")
            return {}
    disable_eyes = False
    if (
        "issue" in body
        and "pull_request" in body["issue"]
        and "url" in body["issue"]["pull_request"]
    ):
        api_url = body["issue"]["pull_request"]["url"]
    elif "comment" in body and "pull_request_url" in body["comment"]:
        api_url = body["comment"]["pull_request_url"]
    try:
        if (
            '/ask' in comment_body
            and 'subject_type' in body["comment"]
            and body["comment"]["subject_type"] == "line"
        ):
            # comment on a code line in the "files changed" tab
            comment_body = handle_line_comments(body, comment_body)
            disable_eyes = True
@@ -113,46 +142,75 @@ async def handle_comments_on_pr(body: Dict[str, Any],

    comment_id = body.get("comment", {}).get("id")
    provider = get_git_provider_with_context(pr_url=api_url)
    with get_logger().contextualize(**log_context):
        if (
            get_identity_provider().verify_eligibility("github", sender_id, api_url)
            is not Eligibility.NOT_ELIGIBLE
        ):
            get_logger().info(
                f"Processing comment on PR {api_url=}, comment_body={comment_body}"
            )
            await agent.handle_request(
                api_url,
                comment_body,
                notify=lambda: provider.add_eyes_reaction(
                    comment_id, disable_eyes=disable_eyes
                ),
            )
        else:
            get_logger().info(
                f"User {sender=} is not eligible to process comment on PR {api_url=}"
            )
async def handle_new_pr_opened(
    body: Dict[str, Any],
    event: str,
    sender: str,
    sender_id: str,
    action: str,
    log_context: Dict[str, Any],
    agent: PRAgent,
):
    title = body.get("pull_request", {}).get("title", "")

    pull_request, api_url = _check_pull_request_event(action, body, log_context)
    if not (pull_request and api_url):
        get_logger().info(f"Invalid PR event: {action=} {api_url=}")
        return {}
    if (
        action in get_settings().github_app.handle_pr_actions
    ):  # ['opened', 'reopened', 'ready_for_review']
        # logic to ignore PRs with specific titles (e.g. "[Auto] ...")
        apply_repo_settings(api_url)
        if (
            get_identity_provider().verify_eligibility("github", sender_id, api_url)
            is not Eligibility.NOT_ELIGIBLE
        ):
            await _perform_auto_commands_github(
                "pr_commands", agent, body, api_url, log_context
            )
        else:
            get_logger().info(
                f"User {sender=} is not eligible to process PR {api_url=}"
            )
async def handle_push_trigger_for_new_commits(
    body: Dict[str, Any],
    event: str,
    sender: str,
    sender_id: str,
    action: str,
    log_context: Dict[str, Any],
    agent: PRAgent,
):
    pull_request, api_url = _check_pull_request_event(action, body, log_context)
    if not (pull_request and api_url):
        return {}

    apply_repo_settings(
        api_url
    )  # we need to apply the repo settings to get the correct settings for the PR. This is quite expensive - a call to the git provider is made for each PR event.
    if not get_settings().github_app.handle_push_trigger:
        return {}
@@ -162,7 +220,10 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any],

    merge_commit_sha = pull_request.get("merge_commit_sha")
    if before_sha == after_sha:
        return {}
    if (
        get_settings().github_app.push_trigger_ignore_merge_commits
        and after_sha == merge_commit_sha
    ):
        return {}

    # Prevent triggering multiple times for subsequent push triggers when one is enough:

@@ -172,7 +233,9 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any],

    # more commits may have been pushed that led to the subsequent events,
    # so we keep just one waiting as a delegate to trigger the processing for the new commits when done waiting.
    current_active_tasks = _duplicate_push_triggers.setdefault(api_url, 0)
    max_active_tasks = (
        2 if get_settings().github_app.push_trigger_pending_tasks_backlog else 1
    )
    if current_active_tasks < max_active_tasks:
        # first task can enter, and second tasks too if backlog is enabled
        get_logger().info(
@@ -191,12 +254,21 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any],

                f"Waiting to process push trigger for {api_url=} because the first task is still in progress"
            )
            await _pending_task_duplicate_push_conditions[api_url].wait()
            get_logger().info(
                f"Finished waiting to process push trigger for {api_url=} - continue with flow"
            )

    try:
        if (
            get_identity_provider().verify_eligibility("github", sender_id, api_url)
            is not Eligibility.NOT_ELIGIBLE
        ):
            get_logger().info(
                f"Performing incremental review for {api_url=} because of {event=} and {action=}"
            )
            await _perform_auto_commands_github(
                "push_commands", agent, body, api_url, log_context
            )

    finally:
        # release the waiting task block
@@ -213,7 +285,12 @@ def handle_closed_pr(body, event, action, log_context):

    api_url = pull_request.get("url", "")
    pr_statistics = get_git_provider()(pr_url=api_url).calc_pr_statistics(pull_request)
    log_context["api_url"] = api_url
    get_logger().info(
        "PR-Agent statistics for closed PR",
        analytics=True,
        pr_statistics=pr_statistics,
        **log_context,
    )


def get_log_context(body, event, action, build_number):
@@ -228,9 +305,18 @@ def get_log_context(body, event, action, build_number):

        git_org = body.get("organization", {}).get("login", "")
        installation_id = body.get("installation", {}).get("id", "")
        app_name = get_settings().get("CONFIG.APP_NAME", "Unknown")
        log_context = {
            "action": action,
            "event": event,
            "sender": sender,
            "server_type": "github_app",
            "request_id": uuid.uuid4().hex,
            "build_number": build_number,
            "app_name": app_name,
            "repo": repo,
            "git_org": git_org,
            "installation_id": installation_id,
        }
    except Exception as e:
        get_logger().error("Failed to get log context", e)
        log_context = {}
@@ -240,7 +326,10 @@ def get_log_context(body, event, action, build_number):

def is_bot_user(sender, sender_type):
    try:
        # logic to ignore PRs opened by bot
        if (
            get_settings().get("GITHUB_APP.IGNORE_BOT_PR", False)
            and sender_type == "Bot"
        ):
            if 'pr-agent' not in sender:
                get_logger().info(f"Ignoring PR from '{sender=}' because it is a bot")
                return True
@@ -262,7 +351,9 @@ def should_process_pr_logic(body) -> bool:

        ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
        if ignore_pr_users and sender:
            if sender in ignore_pr_users:
                get_logger().info(
                    f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting"
                )
                return False

        # logic to ignore PRs with specific titles
@@ -270,8 +361,12 @@ def should_process_pr_logic(body) -> bool:

        ignore_pr_title_re = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
        if not isinstance(ignore_pr_title_re, list):
            ignore_pr_title_re = [ignore_pr_title_re]
        if ignore_pr_title_re and any(
            re.search(regex, title) for regex in ignore_pr_title_re
        ):
            get_logger().info(
                f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting"
            )
            return False

        # logic to ignore PRs with specific labels or source branches or target branches.
@@ -280,20 +375,32 @@ def should_process_pr_logic(body) -> bool:

            labels = [label['name'] for label in pr_labels]
            if any(label in ignore_pr_labels for label in labels):
                labels_str = ", ".join(labels)
                get_logger().info(
                    f"Ignoring PR with labels '{labels_str}' due to config.ignore_pr_labels settings"
                )
                return False

        # logic to ignore PRs with specific source or target branches
        ignore_pr_source_branches = get_settings().get(
            "CONFIG.IGNORE_PR_SOURCE_BRANCHES", []
        )
        ignore_pr_target_branches = get_settings().get(
            "CONFIG.IGNORE_PR_TARGET_BRANCHES", []
        )
        if pull_request and (ignore_pr_source_branches or ignore_pr_target_branches):
            if any(
                re.search(regex, source_branch) for regex in ignore_pr_source_branches
            ):
                get_logger().info(
                    f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings"
                )
                return False
            if any(
                re.search(regex, target_branch) for regex in ignore_pr_target_branches
            ):
                get_logger().info(
                    f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings"
                )
                return False
    except Exception as e:
        get_logger().error(f"Failed 'should_process_pr_logic': {e}")
@@ -308,11 +415,15 @@ async def handle_request(body: Dict[str, Any], event: str):

        body: The request body.
        event: The GitHub event type (e.g. "pull_request", "issue_comment", etc.).
    """
    action = body.get(
        "action"
    )  # "created", "opened", "reopened", "ready_for_review", "review_requested", "synchronize"
    if not action:
        return {}
    agent = PRAgent()
    log_context, sender, sender_id, sender_type = get_log_context(
        body, event, action, build_number
    )

    # logic to ignore PRs opened by bot, PRs with specific titles, labels, source branches, or target branches
    if is_bot_user(sender, sender_type) and 'check_run' not in body:
@@ -327,21 +438,29 @@ async def handle_request(body: Dict[str, Any], event: str):

    # handle comments on PRs
    elif action == 'created':
        get_logger().debug(f'Request body', artifact=body, event=event)
        await handle_comments_on_pr(
            body, event, sender, sender_id, action, log_context, agent
        )
    # handle new PRs
    elif event == 'pull_request' and action != 'synchronize' and action != 'closed':
        get_logger().debug(f'Request body', artifact=body, event=event)
        await handle_new_pr_opened(
            body, event, sender, sender_id, action, log_context, agent
        )
    elif event == "issue_comment" and 'edited' in action:
        pass  # handle_checkbox_clicked
    # handle pull_request event with synchronize action - "push trigger" for new commits
    elif event == 'pull_request' and action == 'synchronize':
        await handle_push_trigger_for_new_commits(
            body, event, sender, sender_id, action, log_context, agent
        )
    elif event == 'pull_request' and action == 'closed':
        if get_settings().get("CONFIG.ANALYTICS_FOLDER", ""):
            handle_closed_pr(body, event, action, log_context)
    else:
        get_logger().info(
            f"event {event=} action {action=} does not require any handling"
        )
    return {}
@@ -362,7 +481,9 @@ def handle_line_comments(body: Dict, comment_body: [str, Any]) -> str:

    return comment_body


def _check_pull_request_event(
    action: str, body: dict, log_context: dict
) -> Tuple[Dict[str, Any], str]:
    invalid_result = {}, ""
    pull_request = body.get("pull_request")
    if not pull_request:

@@ -373,19 +494,28 @@ def _check_pull_request_event(action: str, body: dict, log_context: dict) -> Tuple[Dict[str, Any], str]:

    log_context["api_url"] = api_url
    if pull_request.get("draft", True) or pull_request.get("state") != "open":
        return invalid_result
    if action in ("review_requested", "synchronize") and pull_request.get(
        "created_at"
    ) == pull_request.get("updated_at"):
        # avoid double reviews when opening a PR for the first time
        return invalid_result
    return pull_request, api_url
async def _perform_auto_commands_github(
    commands_conf: str, agent: PRAgent, body: dict, api_url: str, log_context: dict
):
    apply_repo_settings(api_url)
    if (
        commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
    ):  # auto commands for PR, and auto feedback is disabled
        get_logger().info(
            f"Auto feedback is disabled, skipping auto commands for PR {api_url=}"
        )
        return
    if not should_process_pr_logic(
        body
    ):  # Here we already updated the configuration with the repo settings
        return {}
    commands = get_settings().get(f"github_app.{commands_conf}")
    if not commands:

@@ -398,7 +528,9 @@ async def _perform_auto_commands_github(commands_conf: str, agent: PRAgent, body: dict, api_url: str,

        args = split_command[1:]
        other_args = update_settings_from_args(args)
        new_command = ' '.join([command] + other_args)
        get_logger().info(
            f"{commands_conf}. Performing auto command '{new_command}', for {api_url=}"
        )
        await agent.handle_request(api_url, new_command)
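
To illustrate the loop above with a hypothetical pr_commands entry:

command = "/describe --pr_description.final_update_message=false"  # hypothetical
split_command = command.split(" ")
command, args = split_command[0], split_command[1:]
# command == "/describe", args == ["--pr_description.final_update_message=false"];
# update_settings_from_args(args) presumably applies the --key=value settings and
# returns the arguments to keep, which are re-joined into new_command.
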
@@ -19,10 +19,12 @@ NOTIFICATION_URL = "https://api.github.com/notifications"

async def mark_notification_as_read(headers, notification, session):
    async with session.patch(
        f"https://api.github.com/notifications/threads/{notification['id']}",
        headers=headers,
    ) as mark_read_response:
        if mark_read_response.status != 205:
            get_logger().error(
                f"Failed to mark notification as read. Status code: {mark_read_response.status}"
            )


def now() -> str:
@@ -36,17 +38,21 @@ def now() -> str:

    now_utc = now_utc.replace("+00:00", "Z")
    return now_utc


async def async_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
    agent = PRAgent()
    success = await agent.handle_request(
        pr_url,
        rest_of_comment,
        notify=lambda: git_provider.add_eyes_reaction(comment_id),
    )
    return success


def run_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
    return asyncio.run(
        async_handle_request(pr_url, rest_of_comment, comment_id, git_provider)
    )
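
run_handle_request exists because each worker below runs in a separate multiprocessing.Process, which starts with no event loop; asyncio.run gives every process its own. A minimal standalone sketch of the same pattern:

import asyncio
import multiprocessing

async def _handle(pr_url: str) -> None:
    await asyncio.sleep(0)  # stand-in for agent.handle_request(...)

def _handle_sync(pr_url: str) -> None:
    asyncio.run(_handle(pr_url))  # fresh event loop inside the child process

if __name__ == "__main__":
    p = multiprocessing.Process(target=_handle_sync, args=("https://example.invalid/pr/1",))
    p.start()
    p.join()
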

def process_comment_sync(pr_url, rest_of_comment, comment_id):
@@ -55,7 +61,10 @@ def process_comment_sync(pr_url, rest_of_comment, comment_id):

        git_provider = get_git_provider()(pr_url=pr_url)
        success = run_handle_request(pr_url, rest_of_comment, comment_id, git_provider)
    except Exception as e:
        get_logger().error(
            f"Error processing comment: {e}",
            artifact={"traceback": traceback.format_exc()},
        )


async def process_comment(pr_url, rest_of_comment, comment_id):
@@ -66,22 +75,31 @@ async def process_comment(pr_url, rest_of_comment, comment_id):

        success = await agent.handle_request(
            pr_url,
            rest_of_comment,
            notify=lambda: git_provider.add_eyes_reaction(comment_id),
        )
        get_logger().info(f"Finished processing comment for PR: {pr_url}")
    except Exception as e:
        get_logger().error(
            f"Error processing comment: {e}",
            artifact={"traceback": traceback.format_exc()},
        )


async def is_valid_notification(notification, headers, handled_ids, session, user_id):
    try:
        if 'reason' in notification and notification['reason'] == 'mention':
            if (
                'subject' in notification
                and notification['subject']['type'] == 'PullRequest'
            ):
                pr_url = notification['subject']['url']
                latest_comment = notification['subject']['latest_comment_url']
                if not latest_comment or not isinstance(latest_comment, str):
                    get_logger().debug(f"no latest_comment")
                    return False, handled_ids
                async with session.get(
                    latest_comment, headers=headers
                ) as comment_response:
                    check_prev_comments = False
                    user_tag = "@" + user_id
                    if comment_response.status == 200:
@@ -94,7 +112,9 @@ async def is_valid_notification(notification, headers, handled_ids, session, user_id):

                            handled_ids.add(comment['id'])
                        if 'user' in comment and 'login' in comment['user']:
                            if comment['user']['login'] == user_id:
                                get_logger().debug(
                                    f"comment['user']['login'] == user_id"
                                )
                                check_prev_comments = True
                        comment_body = comment.get('body', '')
                        if not comment_body:
@@ -105,15 +125,28 @@ async def is_valid_notification(notification, headers, handled_ids, session, user_id):

                            get_logger().debug(f"user_tag not in comment_body")
                            check_prev_comments = True
                        else:
                            get_logger().info(
                                f"Polling, pr_url: {pr_url}",
                                artifact={"comment": comment_body},
                            )

                    if not check_prev_comments:
                        return (
                            True,
                            handled_ids,
                            comment,
                            comment_body,
                            pr_url,
                            user_tag,
                        )
                    else:  # we could not find the user tag in the latest comment. Check previous comments
                        # get all comments in the PR
                        requests_url = f"{pr_url}/comments".replace(
                            "pulls", "issues"
                        )
                        comments_response = requests.get(
                            requests_url, headers=headers
                        )
                        comments = comments_response.json()[::-1]
                        max_comment_to_scan = 4
                        for comment in comments[:max_comment_to_scan]:
@@ -124,23 +157,37 @@ async def is_valid_notification(notification, headers, handled_ids, session, user_id):

                            if not comment_body:
                                continue
                            if user_tag in comment_body:
                                get_logger().info(
                                    "found user tag in previous comments"
                                )
                                get_logger().info(
                                    f"Polling, pr_url: {pr_url}",
                                    artifact={"comment": comment_body},
                                )
                                return (
                                    True,
                                    handled_ids,
                                    comment,
                                    comment_body,
                                    pr_url,
                                    user_tag,
                                )

                        get_logger().warning(
                            f"Failed to fetch comments for PR: {pr_url}",
                            artifact={"comments": comments},
                        )
                        return False, handled_ids

        return False, handled_ids
    except Exception as e:
        get_logger().exception(
            f"Error processing polling notification",
            artifact={"notification": notification, "error": e},
        )
        return False, handled_ids


async def polling_loop():
    """
    Polls for notifications and handles them accordingly.
@@ -171,17 +218,17 @@ async def polling_loop():

                await asyncio.sleep(5)
                headers = {
                    "Accept": "application/vnd.github.v3+json",
                    "Authorization": f"Bearer {token}",
                }
                params = {"participating": "true"}
                if since[0]:
                    params["since"] = since[0]
                if last_modified[0]:
                    headers["If-Modified-Since"] = last_modified[0]

                async with session.get(
                    NOTIFICATION_URL, headers=headers, params=params
                ) as response:
                    if response.status == 200:
                        if 'Last-Modified' in response.headers:
                            last_modified[0] = response.headers['Last-Modified']
@ -189,39 +236,67 @@ async def polling_loop():
notifications = await response.json() notifications = await response.json()
if not notifications: if not notifications:
continue continue
get_logger().info(f"Received {len(notifications)} notifications") get_logger().info(
f"Received {len(notifications)} notifications"
)
task_queue = deque() task_queue = deque()
for notification in notifications: for notification in notifications:
if not notification: if not notification:
continue continue
# mark notification as read # mark notification as read
await mark_notification_as_read(headers, notification, session) await mark_notification_as_read(
headers, notification, session
)
handled_ids.add(notification['id']) handled_ids.add(notification['id'])
output = await is_valid_notification(notification, headers, handled_ids, session, user_id) output = await is_valid_notification(
notification, headers, handled_ids, session, user_id
)
if output[0]: if output[0]:
_, handled_ids, comment, comment_body, pr_url, user_tag = output (
rest_of_comment = comment_body.split(user_tag)[1].strip() _,
handled_ids,
comment,
comment_body,
pr_url,
user_tag,
) = output
rest_of_comment = comment_body.split(user_tag)[
1
].strip()
comment_id = comment['id'] comment_id = comment['id']
# Add to the task queue # Add to the task queue
get_logger().info( get_logger().info(
f"Adding comment processing to task queue for PR, {pr_url}, comment_body: {comment_body}") f"Adding comment processing to task queue for PR, {pr_url}, comment_body: {comment_body}"
task_queue.append((process_comment_sync, (pr_url, rest_of_comment, comment_id))) )
get_logger().info(f"Queued comment processing for PR: {pr_url}") task_queue.append(
(
process_comment_sync,
(pr_url, rest_of_comment, comment_id),
)
)
get_logger().info(
f"Queued comment processing for PR: {pr_url}"
)
else: else:
get_logger().debug(f"Skipping comment processing for PR") get_logger().debug(
f"Skipping comment processing for PR"
)
max_allowed_parallel_tasks = 10 max_allowed_parallel_tasks = 10
if task_queue: if task_queue:
processes = [] processes = []
for i, (func, args) in enumerate(task_queue): # Create parallel tasks for i, (func, args) in enumerate(
task_queue
): # Create parallel tasks
p = multiprocessing.Process(target=func, args=args) p = multiprocessing.Process(target=func, args=args)
processes.append(p) processes.append(p)
p.start() p.start()
if i > max_allowed_parallel_tasks: if i > max_allowed_parallel_tasks:
get_logger().error( get_logger().error(
f"Dropping {len(task_queue) - max_allowed_parallel_tasks} tasks from polling session") f"Dropping {len(task_queue) - max_allowed_parallel_tasks} tasks from polling session"
)
break break
task_queue.clear() task_queue.clear()
@ -230,11 +305,15 @@ async def polling_loop():
# p.join() # p.join()
elif response.status != 304: elif response.status != 304:
print(f"Failed to fetch notifications. Status code: {response.status}") print(
f"Failed to fetch notifications. Status code: {response.status}"
)
except Exception as e: except Exception as e:
get_logger().error(f"Polling exception during processing of a notification: {e}", get_logger().error(
artifact={"traceback": traceback.format_exc()}) f"Polling exception during processing of a notification: {e}",
artifact={"traceback": traceback.format_exc()},
)
if __name__ == '__main__': if __name__ == '__main__':
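The hunk above fans comment-processing jobs out to `multiprocessing.Process` workers and drops anything beyond `max_allowed_parallel_tasks`. A minimal standalone sketch of that bounded fan-out pattern, with an illustrative worker in place of the repo's `process_comment_sync`:

```python
import multiprocessing
from collections import deque


def process_item(item):
    # stand-in worker; the real code calls process_comment_sync(...)
    print(f"processing {item}")


if __name__ == '__main__':
    task_queue = deque((process_item, (i,)) for i in range(15))
    max_allowed_parallel_tasks = 10
    processes = []
    for i, (func, args) in enumerate(task_queue):
        p = multiprocessing.Process(target=func, args=args)
        processes.append(p)
        p.start()
        if i > max_allowed_parallel_tasks:
            # work beyond the cap is dropped, mirroring the hunk above
            break
    task_queue.clear()
```

Note the design choice: excess tasks are discarded rather than queued, which keeps a single polling iteration from ever blocking on a backlog.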
View File
@@ -22,20 +22,21 @@ from utils.pr_agent.secret_providers import get_secret_provider
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
router = APIRouter()
secret_provider = (
    get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
)


async def get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id):
    try:
        import requests

        headers = {'Private-Token': f'{gitlab_token}'}
        # API endpoint to find MRs containing the commit
        gitlab_url = get_settings().get("GITLAB.URL", 'https://gitlab.com')
        response = requests.get(
            f'{gitlab_url}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/merge_requests',
            headers=headers,
        )
        merge_requests = response.json()
        if merge_requests and response.status_code == 200:

@@ -48,6 +49,7 @@ async def get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id):
        get_logger().error(f"Failed to get MR url from commit sha: {e}")
        return None


async def handle_request(api_url: str, body: str, log_context: dict, sender_id: str):
    log_context["action"] = body
    log_context["event"] = "pull_request" if body == "/review" else "comment"
@@ -58,11 +60,17 @@ async def handle_request(api_url: str, body: str, log_context: dict, sender_id:
        await PRAgent().handle_request(api_url, body)


async def _perform_commands_gitlab(
    commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict
):
    apply_repo_settings(api_url)
    if (
        commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
    ):  # auto commands for PR, and auto feedback is disabled
        get_logger().info(
            f"Auto feedback is disabled, skipping auto commands for PR {api_url=}",
            **log_context,
        )
        return
    if not should_process_pr_logic(data):  # Here we already updated the configurations
        return
@@ -106,40 +114,58 @@ def should_process_pr_logic(data) -> bool:
        ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
        if ignore_pr_users and sender:
            if sender in ignore_pr_users:
                get_logger().info(
                    f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' settings"
                )
                return False

        # logic to ignore MRs for titles, labels and source, target branches.
        ignore_mr_title = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
        ignore_mr_labels = get_settings().get("CONFIG.IGNORE_PR_LABELS", [])
        ignore_mr_source_branches = get_settings().get(
            "CONFIG.IGNORE_PR_SOURCE_BRANCHES", []
        )
        ignore_mr_target_branches = get_settings().get(
            "CONFIG.IGNORE_PR_TARGET_BRANCHES", []
        )

        #
        if ignore_mr_source_branches:
            source_branch = data['object_attributes'].get('source_branch')
            if any(
                re.search(regex, source_branch) for regex in ignore_mr_source_branches
            ):
                get_logger().info(
                    f"Ignoring MR with source branch '{source_branch}' due to gitlab.ignore_mr_source_branches settings"
                )
                return False

        if ignore_mr_target_branches:
            target_branch = data['object_attributes'].get('target_branch')
            if any(
                re.search(regex, target_branch) for regex in ignore_mr_target_branches
            ):
                get_logger().info(
                    f"Ignoring MR with target branch '{target_branch}' due to gitlab.ignore_mr_target_branches settings"
                )
                return False

        if ignore_mr_labels:
            labels = [
                label['title'] for label in data['object_attributes'].get('labels', [])
            ]
            if any(label in ignore_mr_labels for label in labels):
                labels_str = ", ".join(labels)
                get_logger().info(
                    f"Ignoring MR with labels '{labels_str}' due to gitlab.ignore_mr_labels settings"
                )
                return False

        if ignore_mr_title:
            if any(re.search(regex, title) for regex in ignore_mr_title):
                get_logger().info(
                    f"Ignoring MR with title '{title}' due to gitlab.ignore_mr_title settings"
                )
                return False
    except Exception as e:
        get_logger().error(f"Failed 'should_process_pr_logic': {e}")
@@ -159,29 +185,47 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
        request_token = request.headers.get("X-Gitlab-Token")
        secret = secret_provider.get_secret(request_token)
        if not secret:
            get_logger().warning(
                f"Empty secret retrieved, request_token: {request_token}"
            )
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content=jsonable_encoder({"message": "unauthorized"}),
            )
        try:
            secret_dict = json.loads(secret)
            gitlab_token = secret_dict["gitlab_token"]
            log_context["token_id"] = secret_dict.get(
                "token_name", secret_dict.get("id", "unknown")
            )
            context["settings"].gitlab.personal_access_token = gitlab_token
        except Exception as e:
            get_logger().error(f"Failed to validate secret {request_token}: {e}")
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content=jsonable_encoder({"message": "unauthorized"}),
            )
    elif get_settings().get("GITLAB.SHARED_SECRET"):
        secret = get_settings().get("GITLAB.SHARED_SECRET")
        if not request.headers.get("X-Gitlab-Token") == secret:
            get_logger().error("Failed to validate secret")
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content=jsonable_encoder({"message": "unauthorized"}),
            )
    else:
        get_logger().error("Failed to validate secret")
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content=jsonable_encoder({"message": "unauthorized"}),
        )

    gitlab_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
    if not gitlab_token:
        get_logger().error("No gitlab token found")
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content=jsonable_encoder({"message": "unauthorized"}),
        )

    get_logger().info("GitLab data", artifact=data)
    sender = data.get("user", {}).get("username", "unknown")
@@ -189,31 +233,49 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
        # ignore bot users
        if is_bot_user(data):
            return JSONResponse(
                status_code=status.HTTP_200_OK,
                content=jsonable_encoder({"message": "success"}),
            )
        if data.get('event_type') != 'note':  # not a comment
            # ignore MRs based on title, labels, source and target branches
            if not should_process_pr_logic(data):
                return JSONResponse(
                    status_code=status.HTTP_200_OK,
                    content=jsonable_encoder({"message": "success"}),
                )

        log_context["sender"] = sender
        if data.get('object_kind') == 'merge_request' and data['object_attributes'].get(
            'action'
        ) in ['open', 'reopen']:
            title = data['object_attributes'].get('title')
            url = data['object_attributes'].get('url')
            draft = data['object_attributes'].get('draft')
            get_logger().info(f"New merge request: {url}")
            if draft:
                get_logger().info(f"Skipping draft MR: {url}")
                return JSONResponse(
                    status_code=status.HTTP_200_OK,
                    content=jsonable_encoder({"message": "success"}),
                )

            await _perform_commands_gitlab(
                "pr_commands", PRAgent(), url, log_context, data
            )
        elif (
            data.get('object_kind') == 'note' and data.get('event_type') == 'note'
        ):  # comment on MR
            if 'merge_request' in data:
                mr = data['merge_request']
                url = mr.get('url')
                get_logger().info(f"A comment has been added to a merge request: {url}")
                body = data.get('object_attributes', {}).get('note')
                if (
                    data.get('object_attributes', {}).get('type') == 'DiffNote'
                    and '/ask' in body
                ):  # /ask_line
                    body = handle_ask_line(body, data)

                await handle_request(url, body, log_context, sender_id)
@@ -221,30 +283,44 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
            try:
                project_id = data['project_id']
                commit_sha = data['checkout_sha']
                url = await get_mr_url_from_commit_sha(
                    commit_sha, gitlab_token, project_id
                )
                if not url:
                    get_logger().info(f"No MR found for commit: {commit_sha}")
                    return JSONResponse(
                        status_code=status.HTTP_200_OK,
                        content=jsonable_encoder({"message": "success"}),
                    )

                # we need first to apply_repo_settings
                apply_repo_settings(url)
                commands_on_push = get_settings().get(f"gitlab.push_commands", {})
                handle_push_trigger = get_settings().get(
                    f"gitlab.handle_push_trigger", False
                )
                if not commands_on_push or not handle_push_trigger:
                    get_logger().info(
                        "Push event, but no push commands found or push trigger is disabled"
                    )
                    return JSONResponse(
                        status_code=status.HTTP_200_OK,
                        content=jsonable_encoder({"message": "success"}),
                    )

                get_logger().debug(f'A push event has been received: {url}')
                await _perform_commands_gitlab(
                    "push_commands", PRAgent(), url, log_context, data
                )
            except Exception as e:
                get_logger().error(f"Failed to handle push event: {e}")

    background_tasks.add_task(inner, request_json)
    end_time = datetime.now()
    get_logger().info(f"Processing time: {end_time - start_time}", request=request_json)
    return JSONResponse(
        status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"})
    )


def handle_ask_line(body, data):

@@ -271,6 +347,7 @@ def handle_ask_line(body, data):
async def root():
    return {"status": "ok"}


gitlab_url = get_settings().get("GITLAB.URL", None)
if not gitlab_url:
    raise ValueError("GITLAB.URL is not set")
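All three authentication branches above reduce to the same gate: compare the webhook's `X-Gitlab-Token` header against a configured secret before doing any work. A minimal sketch of that gate in isolation, where the endpoint path and secret value are invented for illustration:

```python
from fastapi import FastAPI, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse

app = FastAPI()
SHARED_SECRET = "example-shared-secret"  # stands in for GITLAB.SHARED_SECRET


@app.post("/webhook")
async def webhook(request: Request):
    # reject early if the caller does not present the shared secret
    if request.headers.get("X-Gitlab-Token") != SHARED_SECRET:
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content=jsonable_encoder({"message": "unauthorized"}),
        )
    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content=jsonable_encoder({"message": "success"}),
    )
```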
View File
@@ -1,19 +1,20 @@
class HelpMessage:
    @staticmethod
    def get_general_commands_text():
        commands_text = (
            "> - **/review**: Request a review of your Pull Request. \n"
            "> - **/describe**: Update the PR title and description based on the contents of the PR. \n"
            "> - **/improve [--extended]**: Suggest code improvements. Extended mode provides a higher quality feedback. \n"
            "> - **/ask \\<QUESTION\\>**: Ask a question about the PR. \n"
            "> - **/update_changelog**: Update the changelog based on the PR's contents. \n"
            "> - **/add_docs** 💎: Generate docstring for new components introduced in the PR. \n"
            "> - **/generate_labels** 💎: Generate labels for the PR based on the PR's contents. \n"
            "> - **/analyze** 💎: Automatically analyzes the PR, and presents changes walkthrough for each component. \n\n"
            ">See the [tools guide](https://pr-agent-docs.codium.ai/tools/) for more details.\n"
            ">To list the possible configuration parameters, add a **/config** comment. \n"
        )
        return commands_text

    @staticmethod
    def get_general_bot_help_text():
        output = f"> To invoke the PR-Agent, add a comment using one of the following commands: \n{HelpMessage.get_general_commands_text()} \n"

@@ -22,8 +23,10 @@ class HelpMessage:
    @staticmethod
    def get_review_usage_guide():
        output = "**Overview:**\n"
        output += (
            "The `review` tool scans the PR code changes, and generates a PR review which includes several types of feedbacks, such as possible PR issues, security threats and relevant test in the PR. More feedbacks can be [added](https://pr-agent-docs.codium.ai/tools/review/#general-configurations) by configuring the tool.\n\n"
            "The tool can be triggered [automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) every time a new PR is opened, or can be invoked manually by commenting on any PR.\n"
        )
        output += """\
- When commenting, to edit [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L23) related to the review tool (`pr_reviewer` section), use the following template:
```

@@ -41,8 +44,6 @@ some_config2=...
        return output

    @staticmethod
    def get_describe_usage_guide():
        output = "**Overview:**\n"

@@ -137,7 +138,6 @@ Use triple quotes to write multi-line instructions. Use bullet points to make th
'''
        output += "\n\n</details></td></tr>\n\n"

        # general
        output += "\n\n<tr><td><details> <summary><strong> More PR-Agent commands</strong></summary><hr> \n\n"
        output += HelpMessage.get_general_bot_help_text()

@@ -175,7 +175,6 @@ You can ask questions about the entire PR, about specific code lines, or about a
        return output

    @staticmethod
    def get_improve_usage_guide():
        output = "**Overview:**\n"
View File
@@ -18,8 +18,12 @@ def verify_signature(payload_body, secret_token, signature_header):
        signature_header: header received from GitHub (x-hub-signature-256)
    """
    if not signature_header:
        raise HTTPException(
            status_code=403, detail="x-hub-signature-256 header is missing!"
        )
    hash_object = hmac.new(
        secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256
    )
    expected_signature = "sha256=" + hash_object.hexdigest()
    if not hmac.compare_digest(expected_signature, signature_header):
        raise HTTPException(status_code=403, detail="Request signatures didn't match!")
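The check above is the standard GitHub webhook HMAC scheme: hash the raw request body with the shared secret and compare against the `x-hub-signature-256` header in constant time. A quick sketch of producing and verifying such a signature, with dummy secret and payload values:

```python
import hashlib
import hmac

secret_token = "dummy-secret"
payload_body = b'{"action": "opened"}'

# what GitHub would compute and send in x-hub-signature-256
digest = hmac.new(
    secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256
).hexdigest()
signature_header = "sha256=" + digest

# constant-time comparison, as in verify_signature above
assert hmac.compare_digest("sha256=" + digest, signature_header)
```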
@@ -27,6 +31,7 @@ def verify_signature(payload_body, secret_token, signature_header):
class RateLimitExceeded(Exception):
    """Raised when the git provider API rate limit has been exceeded."""

    pass

@@ -66,7 +71,11 @@ class DefaultDictWithTimeout(defaultdict):
        request_time = self.__time()
        if request_time - self.__last_refresh > self.__refresh_interval:
            return
        to_delete = [
            key
            for key, key_time in self.__key_times.items()
            if request_time - key_time > self.__ttl
        ]
        for key in to_delete:
            del self[key]
        self.__last_refresh = request_time
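The reformatted comprehension above is the pruning step of a TTL-bounded mapping: keys older than the TTL are collected, then deleted. A standalone sketch of the same idea, where the class name and fields are illustrative rather than taken from this repo:

```python
import time


class TTLDict(dict):
    """Toy dict that drops entries older than a fixed TTL when pruned."""

    def __init__(self, ttl_seconds):
        super().__init__()
        self._ttl = ttl_seconds
        self._key_times = {}

    def __setitem__(self, key, value):
        # remember when each key was written
        self._key_times[key] = time.monotonic()
        super().__setitem__(key, value)

    def prune(self):
        now = time.monotonic()
        # same two-phase pattern as the hunk above: collect, then delete
        to_delete = [
            key
            for key, key_time in self._key_times.items()
            if now - key_time > self._ttl
        ]
        for key in to_delete:
            del self[key]
            del self._key_times[key]


d = TTLDict(ttl_seconds=0.01)
d["a"] = 1
time.sleep(0.02)
d.prune()
assert "a" not in d
```

Collecting keys into `to_delete` first avoids mutating the dict while iterating over it, which is why the original code uses the same two-phase shape.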
View File
@@ -17,9 +17,13 @@ from utils.pr_agent.log import get_logger
class PRAddDocs:
    def __init__(
        self,
        pr_url: str,
        cli_mode=False,
        args: list = None,
        ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
    ):
        self.git_provider = get_git_provider()(pr_url)
        self.main_language = get_main_pr_language(
            self.git_provider.get_languages(), self.git_provider.get_files()
        )

@@ -39,13 +43,16 @@ class PRAddDocs:
            "diff": "",  # empty diff for initial calculation
            "extra_instructions": get_settings().pr_add_docs.extra_instructions,
            "commit_messages_str": self.git_provider.get_commit_messages(),
            'docs_for_language': get_docs_for_language(
                self.main_language, get_settings().pr_add_docs.docs_style
            ),
        }
        self.token_handler = TokenHandler(
            self.git_provider.pr,
            self.vars,
            get_settings().pr_add_docs_prompt.system,
            get_settings().pr_add_docs_prompt.user,
        )

    async def run(self):
        try:
@@ -66,16 +73,20 @@ class PRAddDocs:
            get_logger().info('Pushing inline code documentation...')
            self.push_inline_docs(data)
        except Exception as e:
            get_logger().error(
                f"Failed to generate code documentation for PR, error: {e}"
            )

    async def _prepare_prediction(self, model: str):
        get_logger().info('Getting PR diff...')
        self.patches_diff = get_pr_diff(
            self.git_provider,
            self.token_handler,
            model,
            add_line_numbers_to_hunks=True,
            disable_extra_lines=False,
        )
        get_logger().info('Getting AI prediction...')
        self.prediction = await self._get_prediction(model)

@@ -84,13 +95,21 @@ class PRAddDocs:
        variables = copy.deepcopy(self.vars)
        variables["diff"] = self.patches_diff  # update diff
        environment = Environment(undefined=StrictUndefined)
        system_prompt = environment.from_string(
            get_settings().pr_add_docs_prompt.system
        ).render(variables)
        user_prompt = environment.from_string(
            get_settings().pr_add_docs_prompt.user
        ).render(variables)
        if get_settings().config.verbosity_level >= 2:
            get_logger().info(f"\nSystem prompt:\n{system_prompt}")
            get_logger().info(f"\nUser prompt:\n{user_prompt}")
        response, finish_reason = await self.ai_handler.chat_completion(
            model=model,
            temperature=get_settings().config.temperature,
            system=system_prompt,
            user=user_prompt,
        )
        return response

@@ -105,7 +124,9 @@ class PRAddDocs:
        docs = []
        if not data['Code Documentation']:
            return self.git_provider.publish_comment(
                'No code documentation found to improve this PR.'
            )

        for d in data['Code Documentation']:
            try:
@@ -116,32 +137,59 @@ class PRAddDocs:
                documentation = d['documentation']
                doc_placement = d['doc placement'].strip()
                if documentation:
                    new_code_snippet = self.dedent_code(
                        relevant_file,
                        relevant_line,
                        documentation,
                        doc_placement,
                        add_original_line=True,
                    )
                    body = (
                        f"**Suggestion:** Proposed documentation\n```suggestion\n"
                        + new_code_snippet
                        + "\n```"
                    )
                    docs.append(
                        {
                            'body': body,
                            'relevant_file': relevant_file,
                            'relevant_lines_start': relevant_line,
                            'relevant_lines_end': relevant_line,
                        }
                    )
            except Exception:
                if get_settings().config.verbosity_level >= 2:
                    get_logger().info(f"Could not parse code docs: {d}")
        is_successful = self.git_provider.publish_code_suggestions(docs)
        if not is_successful:
            get_logger().info(
                "Failed to publish code docs, trying to publish each docs separately"
            )
            for doc_suggestion in docs:
                self.git_provider.publish_code_suggestions([doc_suggestion])

    def dedent_code(
        self,
        relevant_file,
        relevant_lines_start,
        new_code_snippet,
        doc_placement='after',
        add_original_line=False,
    ):
        try:  # dedent code snippet
            self.diff_files = (
                self.git_provider.diff_files
                if self.git_provider.diff_files
                else self.git_provider.get_diff_files()
            )
            original_initial_line = None
            for file in self.diff_files:
                if file.filename.strip() == relevant_file:
                    original_initial_line = file.head_file.splitlines()[
                        relevant_lines_start - 1
                    ]
                    break
            if original_initial_line:
                if doc_placement == 'after':

@@ -150,18 +198,28 @@ class PRAddDocs:
                    line = original_initial_line
                suggested_initial_line = new_code_snippet.splitlines()[0]
                original_initial_spaces = len(line) - len(line.lstrip())
                suggested_initial_spaces = len(suggested_initial_line) - len(
                    suggested_initial_line.lstrip()
                )
                delta_spaces = original_initial_spaces - suggested_initial_spaces
                if delta_spaces > 0:
                    new_code_snippet = textwrap.indent(
                        new_code_snippet, delta_spaces * " "
                    ).rstrip('\n')
                if add_original_line:
                    if doc_placement == 'after':
                        new_code_snippet = (
                            original_initial_line + "\n" + new_code_snippet
                        )
                    else:
                        new_code_snippet = (
                            new_code_snippet.rstrip() + "\n" + original_initial_line
                        )
        except Exception as e:
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(
                    f"Could not dedent code snippet for file {relevant_file}, error: {e}"
                )
        return new_code_snippet
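The `dedent_code` hunks above re-align a suggested snippet with the indentation of the original line: measure leading spaces on both, then indent the suggestion by the difference. The arithmetic in isolation, using invented sample strings:

```python
import textwrap

original_line = "        return value  # 8 leading spaces"
snippet = "def helper():\n    pass"  # first line has 0 leading spaces

first_line = snippet.splitlines()[0]
original_spaces = len(original_line) - len(original_line.lstrip())
suggested_spaces = len(first_line) - len(first_line.lstrip())
delta_spaces = original_spaces - suggested_spaces  # 8 - 0 = 8
if delta_spaces > 0:
    # shift every line of the snippet right by the difference
    snippet = textwrap.indent(snippet, delta_spaces * " ").rstrip('\n')
print(snippet)
```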
File diff suppressed because it is too large

View File
@@ -9,6 +9,7 @@ class PRConfig:
    """
    The PRConfig class is responsible for listing all configuration options available for the user.
    """

    def __init__(self, pr_url: str, args=None, ai_handler=None):
        """
        Initialize the PRConfig object with the necessary attributes and objects to comment on a pull request.

@@ -34,20 +35,43 @@ class PRConfig:
        conf_settings = Dynaconf(settings_files=[conf_file])
        configuration_headers = [header.lower() for header in conf_settings.keys()]
        relevant_configs = {
            header: configs
            for header, configs in get_settings().to_dict().items()
            if (header.lower().startswith("pr_") or header.lower().startswith("config"))
            and header.lower() in configuration_headers
        }

        skip_keys = [
            'ai_disclaimer',
            'ai_disclaimer_title',
            'ANALYTICS_FOLDER',
            'secret_provider',
            "skip_keys",
            "app_id",
            "redirect",
            'trial_prefix_message',
            'no_eligible_message',
            'identity_provider',
            'ALLOWED_REPOS',
            'APP_NAME',
            'PERSONAL_ACCESS_TOKEN',
            'shared_secret',
            'key',
            'AWS_ACCESS_KEY_ID',
            'AWS_SECRET_ACCESS_KEY',
            'user_token',
            'private_key',
            'private_key_id',
            'client_id',
            'client_secret',
            'token',
            'bearer_token',
        ]
        extra_skip_keys = get_settings().config.get('config.skip_keys', [])
        if extra_skip_keys:
            skip_keys.extend(extra_skip_keys)
        skip_keys_lower = [key.lower() for key in skip_keys]

        markdown_text = "<details> <summary><strong>🛠️ PR-Agent Configurations:</strong></summary> \n\n"
        markdown_text += f"\n\n```yaml\n\n"
        for header, configs in relevant_configs.items():

@@ -61,5 +85,7 @@ class PRConfig:
                markdown_text += " "
        markdown_text += "\n```"
        markdown_text += "\n</details>\n"
        get_logger().info(
            f"Possible Configurations outputted to PR comment", artifact=markdown_text
        )
        return markdown_text
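The relevant-config filter above keeps only settings sections whose name starts with `pr_` or `config`, then hides anything in the sensitive `skip_keys` list. The same selection logic on a toy settings dict, with made-up section and key names:

```python
settings = {
    "pr_description": {"publish_labels": True, "user_token": "secret"},
    "config": {"verbosity_level": 2},
    "github": {"deployment_type": "user"},
}
configuration_headers = ["pr_description", "config"]
skip_keys_lower = ["user_token"]

# keep only pr_*/config sections that are known configuration headers
relevant_configs = {
    header: configs
    for header, configs in settings.items()
    if (header.lower().startswith("pr_") or header.lower().startswith("config"))
    and header.lower() in configuration_headers
}
for header, configs in relevant_configs.items():
    # drop sensitive keys before displaying
    shown = {k: v for k, v in configs.items() if k.lower() not in skip_keys_lower}
    print(header, shown)
```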
View File
@@ -10,27 +10,38 @@ from jinja2 import Environment, StrictUndefined
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from utils.pr_agent.algo.pr_processing import (
    OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD,
    get_pr_diff,
    get_pr_diff_multiple_patchs,
    retry_with_fallback_models,
)
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.utils import (
    ModelType,
    PRDescriptionHeader,
    clip_tokens,
    get_max_tokens,
    get_user_labels,
    load_yaml,
    set_custom_labels,
    show_relevant_configurations,
)
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import GithubProvider, get_git_provider_with_context
from utils.pr_agent.git_providers.git_provider import get_main_pr_language
from utils.pr_agent.log import get_logger
from utils.pr_agent.servers.help import HelpMessage
from utils.pr_agent.tools.ticket_pr_compliance_check import extract_and_cache_pr_tickets


class PRDescription:
    def __init__(
        self,
        pr_url: str,
        args: list = None,
        ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
    ):
        """
        Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
        using an AI model.
@@ -44,11 +55,22 @@ class PRDescription:
            self.git_provider.get_languages(), self.git_provider.get_files()
        )
        self.pr_id = self.git_provider.get_pr_id()
        self.keys_fix = [
            "filename:",
            "language:",
            "changes_summary:",
            "changes_title:",
            "description:",
            "title:",
        ]

        if (
            get_settings().pr_description.enable_semantic_files_types
            and not self.git_provider.is_supported("gfm_markdown")
        ):
            get_logger().debug(
                f"Disabling semantic files types for {self.pr_id}, gfm_markdown not supported."
            )
            get_settings().pr_description.enable_semantic_files_types = False

        # Initialize the AI handler
@@ -56,7 +78,9 @@ class PRDescription:
        self.ai_handler.main_pr_language = self.main_pr_language

        # Initialize the variables dictionary
        self.COLLAPSIBLE_FILE_LIST_THRESHOLD = get_settings().pr_description.get(
            "collapsible_file_list_threshold", 8
        )
        self.vars = {
            "title": self.git_provider.pr.title,
            "branch": self.git_provider.get_pr_branch(),

@@ -69,8 +93,11 @@ class PRDescription:
            "custom_labels_class": "",  # will be filled if necessary in 'set_custom_labels' function
            "enable_semantic_files_types": get_settings().pr_description.enable_semantic_files_types,
            "related_tickets": "",
            "include_file_summary_changes": len(self.git_provider.get_diff_files())
            <= self.COLLAPSIBLE_FILE_LIST_THRESHOLD,
            'duplicate_prompt_examples': get_settings().config.get(
                'duplicate_prompt_examples', False
            ),
        }

        self.user_description = self.git_provider.get_user_description()
@@ -91,10 +118,14 @@ class PRDescription:
    async def run(self):
        try:
            get_logger().info(f"Generating a PR description for pr_id: {self.pr_id}")
            relevant_configs = {
                'pr_description': dict(get_settings().pr_description),
                'config': dict(get_settings().config),
            }
            get_logger().debug("Relevant configs", artifacts=relevant_configs)
            if get_settings().config.publish_output and not get_settings().config.get(
                'is_auto_command', False
            ):
                self.git_provider.publish_comment("准备 PR 描述中...", is_temporary=True)

            # ticket extraction if exists
@@ -119,40 +150,73 @@ class PRDescription:
                get_logger().debug(f"Publishing labels disabled")

            if get_settings().pr_description.use_description_markers:
                (
                    pr_title,
                    pr_body,
                    changes_walkthrough,
                    pr_file_changes,
                ) = self._prepare_pr_answer_with_markers()
            else:
                (
                    pr_title,
                    pr_body,
                    changes_walkthrough,
                    pr_file_changes,
                ) = self._prepare_pr_answer()
            if (
                not self.git_provider.is_supported("publish_file_comments")
                or not get_settings().pr_description.inline_file_summary
            ):
                pr_body += "\n\n" + changes_walkthrough
            get_logger().debug(
                "PR output", artifact={"title": pr_title, "body": pr_body}
            )

            # Add help text if gfm_markdown is supported
            if (
                self.git_provider.is_supported("gfm_markdown")
                and get_settings().pr_description.enable_help_text
            ):
                pr_body += "<hr>\n\n<details> <summary><strong>✨ 工具使用指南:</strong></summary><hr> \n\n"
                pr_body += HelpMessage.get_describe_usage_guide()
                pr_body += "\n</details>\n"
            elif (
                get_settings().pr_description.enable_help_comment
                and self.git_provider.is_supported("gfm_markdown")
            ):
                if isinstance(self.git_provider, GithubProvider):
                    pr_body += (
                        '\n\n___\n\n> <details> <summary> 需要帮助?</summary><li>Type <code>/help 如何 ...</code> '
                        '关于PR-Agent使用的任何问题,请在评论区留言.</li><li>查看一下 '
                        '<a href="https://qodo-merge-docs.qodo.ai/usage-guide/">documentation</a> '
                        '了解更多.</li></details>'
                    )
                else:  # gitlab
                    pr_body += (
                        "\n\n___\n\n<details><summary>需要帮助?</summary>- Type <code>/help 如何 ...</code> 在评论中 "
                        "关于PR-Agent使用的任何问题请在此发帖. <br>- 查看一下 "
                        "<a href='https://qodo-merge-docs.qodo.ai/usage-guide/'>documentation</a> 了解更多.</details>"
                    )
            # elif get_settings().pr_description.enable_help_comment:
            #     pr_body += '\n\n___\n\n> 💡 **PR-Agent usage**: Comment `/help "your question"` on any pull request to receive relevant information'

            # Output the relevant configurations if enabled
            if (
                get_settings()
                .get('config', {})
                .get('output_relevant_configurations', False)
            ):
                pr_body += show_relevant_configurations(
                    relevant_section='pr_description'
                )

            if get_settings().config.publish_output:
                # publish labels
                if (
                    get_settings().pr_description.publish_labels
                    and pr_labels
                    and self.git_provider.is_supported("get_labels")
                ):
                    original_labels = self.git_provider.get_pr_labels(update=True)
                    get_logger().debug(f"original labels", artifact=original_labels)
                    user_labels = get_user_labels(original_labels)
@@ -165,20 +229,29 @@ class PRDescription:
                # publish description
                if get_settings().pr_description.publish_description_as_comment:
                    full_markdown_description = (
                        f"## Title\n\n{pr_title}\n\n___\n{pr_body}"
                    )
                    if (
                        get_settings().pr_description.publish_description_as_comment_persistent
                    ):
                        self.git_provider.publish_persistent_comment(
                            full_markdown_description,
                            initial_header="## Title",
                            update_header=True,
                            name="describe",
                            final_update_message=False,
                        )
                    else:
                        self.git_provider.publish_comment(full_markdown_description)
                else:
                    self.git_provider.publish_description(pr_title, pr_body)

                # publish final update message
                if (
                    get_settings().pr_description.final_update_message
                    and not get_settings().config.get('is_auto_command', False)
                ):
                    latest_commit_url = self.git_provider.get_latest_commit_url()
                    if latest_commit_url:
                        pr_url = self.git_provider.get_pr_url()
@@ -186,22 +259,40 @@ class PRDescription:
                        self.git_provider.publish_comment(update_comment)
                self.git_provider.remove_initial_comment()
            else:
                get_logger().info(
                    'PR description, but not published since publish_output is False.'
                )
                get_settings().data = {"artifact": pr_body}
                return
        except Exception as e:
            get_logger().error(
                f"Error generating PR description {self.pr_id}: {e}",
                artifact={"traceback": traceback.format_exc()},
            )

        return ""

    async def _prepare_prediction(self, model: str) -> None:
        if (
            get_settings().pr_description.use_description_markers
            and 'pr_agent:' not in self.user_description
        ):
            get_logger().info(
                "Markers were enabled, but user description does not contain markers. skipping AI prediction"
            )
            return None

        large_pr_handling = (
            get_settings().pr_description.enable_large_pr_handling
            and "pr_description_only_files_prompts" in get_settings()
        )
        output = get_pr_diff(
            self.git_provider,
            self.token_handler,
            model,
            large_pr_handling=large_pr_handling,
            return_remaining_files=True,
        )
        if isinstance(output, tuple):
            patches_diff, remaining_files_list = output
        else:
@@ -213,14 +304,18 @@ class PRDescription:
            if patches_diff:
                # generate the prediction
                get_logger().debug(f"PR diff", artifact=self.patches_diff)
                self.prediction = await self._get_prediction(
                    model, patches_diff, prompt="pr_description_prompt"
                )

                # extend the prediction with additional files not shown
                if get_settings().pr_description.enable_semantic_files_types:
                    self.prediction = await self.extend_uncovered_files(self.prediction)
            else:
                get_logger().error(
                    f"Error getting PR diff {self.pr_id}",
                    artifact={"traceback": traceback.format_exc()},
                )
                self.prediction = None
        else:
            # get the diff in multiple patches, with the token handler only for the files prompt
@@ -231,9 +326,16 @@ class PRDescription:
                get_settings().pr_description_only_files_prompts.system,
                get_settings().pr_description_only_files_prompts.user,
            )
            (
                patches_compressed_list,
                total_tokens_list,
                deleted_files_list,
                remaining_files_list,
                file_dict,
                files_in_patches_list,
            ) = get_pr_diff_multiple_patchs(
                self.git_provider, token_handler_only_files_prompt, model
            )

            # get the files prediction for each patch
            if not get_settings().pr_description.async_ai_calls:

@@ -241,8 +343,9 @@ class PRDescription:
                for i, patches in enumerate(patches_compressed_list):  # sync calls
                    patches_diff = "\n".join(patches)
                    get_logger().debug(f"PR diff number {i + 1} for describe files")
                    prediction_files = await self._get_prediction(
                        model, patches_diff, prompt="pr_description_only_files_prompts"
                    )
                    results.append(prediction_files)
            else:  # async calls
                tasks = []
@@ -251,34 +354,52 @@ class PRDescription:
                    patches_diff = "\n".join(patches)
                    get_logger().debug(f"PR diff number {i + 1} for describe files")
                    task = asyncio.create_task(
                        self._get_prediction(
                            model,
                            patches_diff,
                            prompt="pr_description_only_files_prompts",
                        )
                    )
                    tasks.append(task)
                # Wait for all tasks to complete
                results = await asyncio.gather(*tasks)

            file_description_str_list = []
            for i, result in enumerate(results):
                prediction_files = (
                    result.strip().removeprefix('```yaml').strip('`').strip()
                )
                if load_yaml(
                    prediction_files, keys_fix_yaml=self.keys_fix
                ) and prediction_files.startswith('pr_files'):
                    prediction_files = prediction_files.removeprefix(
                        'pr_files:'
                    ).strip()
                    file_description_str_list.append(prediction_files)
                else:
                    get_logger().debug(
                        f"failed to generate predictions in iteration {i + 1} for describe files"
                    )

            # generate files_walkthrough string, with proper token handling
            token_handler_only_description_prompt = TokenHandler(
                self.git_provider.pr,
                self.vars,
                get_settings().pr_description_only_description_prompts.system,
                get_settings().pr_description_only_description_prompts.user,
            )
            files_walkthrough = "\n".join(file_description_str_list)
            files_walkthrough_prompt = copy.deepcopy(files_walkthrough)
            MAX_EXTRA_FILES_TO_PROMPT = 50
            if remaining_files_list:
                files_walkthrough_prompt += (
                    "\n\nNo more token budget. Additional unprocessed files:"
                )
                for i, file in enumerate(remaining_files_list):
                    files_walkthrough_prompt += f"\n- {file}"
                    if i >= MAX_EXTRA_FILES_TO_PROMPT:
                        get_logger().debug(
                            f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}"
                        )
                        files_walkthrough_prompt += f"\n... and {len(remaining_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more"
                        break
            if deleted_files_list:
@@ -286,32 +407,57 @@ class PRDescription:
    for i, file in enumerate(deleted_files_list):
        files_walkthrough_prompt += f"\n- {file}"
        if i >= MAX_EXTRA_FILES_TO_PROMPT:
            get_logger().debug(
                f"Too many deleted files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}"
            )
            files_walkthrough_prompt += f"\n... and {len(deleted_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more"
            break
tokens_files_walkthrough = len(
    token_handler_only_description_prompt.encoder.encode(
        files_walkthrough_prompt
    )
)
total_tokens = (
    token_handler_only_description_prompt.prompt_tokens
    + tokens_files_walkthrough
)
max_tokens_model = get_max_tokens(model)
if total_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
    # clip files_walkthrough to get the tokens within the limit
    files_walkthrough_prompt = clip_tokens(
        files_walkthrough_prompt,
        max_tokens_model
        - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD
        - token_handler_only_description_prompt.prompt_tokens,
        num_input_tokens=tokens_files_walkthrough,
    )
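The clipping above enforces a simple budget: prompt tokens plus walkthrough tokens must stay under the model window minus a hard output buffer. The same inequality as a sketch with illustrative numbers (the real thresholds come from OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD and the token handler):

def fits_budget(prompt_tokens: int, extra_tokens: int,
                max_tokens_model: int, output_buffer: int) -> bool:
    # leave output_buffer tokens free for the model's own answer
    return prompt_tokens + extra_tokens <= max_tokens_model - output_buffer

# a 25k-token prompt plus an 8k-token walkthrough overflows a 32k window
print(fits_budget(25_000, 8_000, 32_000, 1_000))  # False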
# PR header inference
get_logger().debug(
    f"PR diff only description", artifact=files_walkthrough_prompt
)
prediction_headers = await self._get_prediction(
    model,
    patches_diff=files_walkthrough_prompt,
    prompt="pr_description_only_description_prompts",
)
prediction_headers = (
    prediction_headers.strip().removeprefix('```yaml').strip('`').strip()
)

# extend the tables with the files not shown
files_walkthrough_extended = await self.extend_uncovered_files(
    files_walkthrough
)

# final processing
self.prediction = (
    prediction_headers + "\n" + "pr_files:\n" + files_walkthrough_extended
)
if not load_yaml(self.prediction, keys_fix_yaml=self.keys_fix):
    get_logger().error(
        f"Error getting valid YAML in large PR handling for describe {self.pr_id}"
    )
    if load_yaml(prediction_headers, keys_fix_yaml=self.keys_fix):
        get_logger().debug(f"Using only headers for describe {self.pr_id}")
        self.prediction = prediction_headers
@@ -321,12 +467,17 @@ class PRDescription:
prediction = original_prediction
# get the original prediction filenames
original_prediction_loaded = load_yaml(
    original_prediction, keys_fix_yaml=self.keys_fix
)
if isinstance(original_prediction_loaded, list):
    original_prediction_dict = {"pr_files": original_prediction_loaded}
else:
    original_prediction_dict = original_prediction_loaded
filenames_predicted = [
    file['filename'].strip()
    for file in original_prediction_dict.get('pr_files', [])
]

# extend the prediction with additional files not included in the original prediction
pr_files = self.git_provider.get_diff_files()
@@ -349,7 +500,9 @@ class PRDescription:
    additional files
"""
                prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
                get_logger().debug(
                    f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_OUTPUT}"
                )
                break
            extra_file_yaml = f"""\
@@ -364,10 +517,18 @@ class PRDescription:
# merge the two dictionaries
if counter_extra_files > 0:
    get_logger().info(
        f"Adding {counter_extra_files} unprocessed extra files to table prediction"
    )
    prediction_extra_dict = load_yaml(
        prediction_extra, keys_fix_yaml=self.keys_fix
    )
    if isinstance(original_prediction_dict, dict) and isinstance(
        prediction_extra_dict, dict
    ):
        original_prediction_dict["pr_files"].extend(
            prediction_extra_dict["pr_files"]
        )
        new_yaml = yaml.dump(original_prediction_dict)
        if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix):
            prediction = new_yaml
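The merge above is plain list concatenation on the pr_files key, followed by a dump-and-reparse round trip to validate the result. A standalone sketch using PyYAML directly (the repo's load_yaml helper adds key fixing on top of this):

import yaml

original = yaml.safe_load("pr_files:\n- filename: a.py\n  label: bug fix")
extra = yaml.safe_load("pr_files:\n- filename: b.py\n  label: additional files")

# append the extra table rows, then re-serialize for validation
if isinstance(original, dict) and isinstance(extra, dict):
    original["pr_files"].extend(extra["pr_files"])
print(yaml.dump(original))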
@@ -379,11 +540,12 @@ class PRDescription:
        get_logger().error(f"Error extending uncovered files {self.pr_id}: {e}")
        return original_prediction

async def extend_additional_files(self, remaining_files_list) -> str:
    prediction = self.prediction
    try:
        original_prediction_dict = load_yaml(
            self.prediction, keys_fix_yaml=self.keys_fix
        )
        prediction_extra = "pr_files:"
        for file in remaining_files_list:
            extra_file_yaml = f"""\
@@ -397,10 +559,16 @@ class PRDescription:
    additional files (token-limit)
"""
                prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
            prediction_extra_dict = load_yaml(
                prediction_extra, keys_fix_yaml=self.keys_fix
            )
            # merge the two dictionaries
            if isinstance(original_prediction_dict, dict) and isinstance(
                prediction_extra_dict, dict
            ):
                original_prediction_dict["pr_files"].extend(
                    prediction_extra_dict["pr_files"]
                )
                new_yaml = yaml.dump(original_prediction_dict)
                if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix):
                    prediction = new_yaml
@@ -409,7 +577,9 @@ class PRDescription:
        get_logger().error(f"Error extending additional files {self.pr_id}: {e}")
        return self.prediction

async def _get_prediction(
    self, model: str, patches_diff: str, prompt="pr_description_prompt"
) -> str:
    variables = copy.deepcopy(self.vars)
    variables["diff"] = patches_diff  # update diff
@@ -417,14 +587,18 @@ class PRDescription:
    set_custom_labels(variables, self.git_provider)
    self.variables = variables

    system_prompt = environment.from_string(
        get_settings().get(prompt, {}).get("system", "")
    ).render(self.variables)
    user_prompt = environment.from_string(
        get_settings().get(prompt, {}).get("user", "")
    ).render(self.variables)
    response, finish_reason = await self.ai_handler.chat_completion(
        model=model,
        temperature=get_settings().config.temperature,
        system=system_prompt,
        user=user_prompt,
    )
    return response
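All of the prompt plumbing touched by this commit follows the same two-step Jinja2 recipe: build an Environment with StrictUndefined, then render the system and user templates against a variables dict. A self-contained sketch (the template text here is made up):

from jinja2 import Environment, StrictUndefined

environment = Environment(undefined=StrictUndefined)
template = "Review the diff for {{ title }}:\n{{ diff }}"
# StrictUndefined makes a missing variable raise instead of rendering as ''
print(environment.from_string(template).render(title="demo PR", diff="..."))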
@@ -433,7 +607,10 @@ class PRDescription:
    # Load the AI prediction data into a dictionary
    self.data = load_yaml(self.prediction.strip(), keys_fix_yaml=self.keys_fix)

    if (
        get_settings().pr_description.add_original_user_description
        and self.user_description
    ):
        self.data["User Description"] = self.user_description

    # re-order keys
@@ -459,7 +636,11 @@ class PRDescription:
            pr_labels = self.data['labels']
        elif type(self.data['labels']) == str:
            pr_labels = self.data['labels'].split(',')
    elif (
        'type' in self.data
        and self.data['type']
        and get_settings().pr_description.publish_labels
    ):
        if type(self.data['type']) == list:
            pr_labels = self.data['type']
        elif type(self.data['type']) == str:
@@ -474,7 +655,9 @@ class PRDescription:
                if label_i in d:
                    pr_labels[i] = d[label_i]
        except Exception as e:
            get_logger().error(
                f"Error converting labels to original case {self.pr_id}: {e}"
            )
        return pr_labels

    def _prepare_pr_answer_with_markers(self) -> Tuple[str, str, str, List[dict]]:
@@ -482,7 +665,7 @@ class PRDescription:
        # Remove the 'PR Title' key from the dictionary
        ai_title = self.data.pop('title', self.vars["title"])
        if not get_settings().pr_description.generate_ai_title:
            # Assign the original PR title to the 'title' variable
            title = self.vars["title"]
        else:
@@ -514,8 +697,9 @@ class PRDescription:
        pr_file_changes = []
        if ai_walkthrough and not re.search(r'<!--\s*pr_agent:walkthrough\s*-->', body):
            try:
                walkthrough_gfm, pr_file_changes = self.process_pr_files_prediction(
                    walkthrough_gfm, self.file_label_dict
                )
                body = body.replace('pr_agent:walkthrough', walkthrough_gfm)
            except Exception as e:
                get_logger().error(f"Failing to process walkthrough {self.pr_id}: {e}")
@@ -545,7 +729,7 @@ class PRDescription:
        # Remove the 'PR Title' key from the dictionary
        ai_title = self.data.pop('title', self.vars["title"])
        if not get_settings().pr_description.generate_ai_title:
            # Assign the original PR title to the 'title' variable
            title = self.vars["title"]
        else:
@@ -575,13 +759,20 @@ class PRDescription:
                    pr_body += f'- `{filename}`: {description}\n'
                if self.git_provider.is_supported("gfm_markdown"):
                    pr_body += "</details>\n"
            elif (
                'pr_files' in key.lower()
                and get_settings().pr_description.enable_semantic_files_types
            ):
                changes_walkthrough, pr_file_changes = self.process_pr_files_prediction(
                    changes_walkthrough, value
                )
                changes_walkthrough = f"{PRDescriptionHeader.CHANGES_WALKTHROUGH.value}\n{changes_walkthrough}"
            elif key.lower().strip() == 'description':
                if isinstance(value, list):
                    value = ', '.join(v.rstrip() for v in value)
                value = value.replace(
                    '\n-', '\n\n-'
                ).strip()  # makes the bullet points more readable by adding double space
                pr_body += f"{value}\n"
            else:
                # if the value is a list, join its items by comma
@@ -591,24 +782,37 @@ class PRDescription:
            if idx < len(self.data) - 1:
                pr_body += "\n\n___\n\n"
        return (
            title,
            pr_body,
            changes_walkthrough,
            pr_file_changes,
        )

    def _prepare_file_labels(self):
        file_label_dict = {}
        if (
            not self.data
            or not isinstance(self.data, dict)
            or 'pr_files' not in self.data
            or not self.data['pr_files']
        ):
            return file_label_dict
        for file in self.data['pr_files']:
            try:
                required_fields = ['changes_title', 'filename', 'label']
                if not all(field in file for field in required_fields):
                    # can happen for example if a YAML generation was interrupted in the middle (no more tokens)
                    get_logger().warning(
                        f"Missing required fields in file label dict {self.pr_id}, skipping file",
                        artifact={"file": file},
                    )
                    continue
                if not file.get('changes_title'):
                    get_logger().warning(
                        f"Empty changes title or summary in file label dict {self.pr_id}, skipping file",
                        artifact={"file": file},
                    )
                    continue
                filename = file['filename'].replace("'", "`").replace('"', '`')
                changes_summary = file.get('changes_summary', "").strip()
@@ -616,7 +820,9 @@ class PRDescription:
                label = file.get('label').strip().lower()
                if label not in file_label_dict:
                    file_label_dict[label] = []
                file_label_dict[label].append(
                    (filename, changes_title, changes_summary)
                )
            except Exception as e:
                get_logger().error(f"Error preparing file label dict {self.pr_id}: {e}")
                pass
@@ -640,7 +846,9 @@ class PRDescription:
        header = f"相关文件"
        delta = 75
        # header += "&nbsp; " * delta
        pr_body += (
            f"""<thead><tr><th></th><th align="left">{header}</th></tr></thead>"""
        )
        pr_body += """<tbody>"""
        for semantic_label in value.keys():
            s_label = semantic_label.strip("'").strip('"')
@@ -651,14 +859,22 @@ class PRDescription:
                pr_body += f"""<td><details><summary>{len(list_tuples)} files</summary><table>"""
            else:
                pr_body += f"""<td><table>"""
            for (
                filename,
                file_changes_title,
                file_change_description,
            ) in list_tuples:
                filename = filename.replace("'", "`").rstrip()
                filename_publish = filename.split("/")[-1]
                if file_changes_title and file_changes_title.strip() != "...":
                    file_changes_title_code = f"<code>{file_changes_title}</code>"
                    file_changes_title_code_br = insert_br_after_x_chars(
                        file_changes_title_code, x=(delta - 5)
                    ).strip()
                    if len(file_changes_title_code_br) < (delta - 5):
                        file_changes_title_code_br += "&nbsp; " * (
                            (delta - 5) - len(file_changes_title_code_br)
                        )
                    filename_publish = f"<strong>{filename_publish}</strong><dd>{file_changes_title_code_br}</dd>"
                else:
                    filename_publish = f"<strong>{filename_publish}</strong>"
@@ -679,15 +895,30 @@ class PRDescription:
                link = ""
                if hasattr(self.git_provider, 'get_line_link'):
                    filename = filename.strip()
                    link = self.git_provider.get_line_link(
                        filename, relevant_line_start=-1
                    )
                    if (not link or not diff_plus_minus) and (
                        'additional files' not in filename.lower()
                    ):
                        get_logger().warning(
                            f"Error getting line link for '{filename}'"
                        )
                        continue

                # Add file data to the PR body
                file_change_description_br = insert_br_after_x_chars(
                    file_change_description, x=(delta - 5)
                )
                pr_body = self.add_file_data(
                    delta_nbsp,
                    diff_plus_minus,
                    file_change_description_br,
                    filename,
                    filename_publish,
                    link,
                    pr_body,
                )

            # Close the collapsible file list
            if use_collapsible_file_list:
@@ -697,13 +928,22 @@ class PRDescription:
                pr_body += """</tr></tbody></table>"""
        except Exception as e:
            get_logger().error(
                f"Error processing pr files to markdown {self.pr_id}: {str(e)}"
            )
            pass
        return pr_body, pr_comments

    def add_file_data(
        self,
        delta_nbsp,
        diff_plus_minus,
        file_change_description_br,
        filename,
        filename_publish,
        link,
        pr_body,
    ) -> str:
        if not file_change_description_br:
            pr_body += f"""
<tr>
@@ -735,6 +975,7 @@ class PRDescription:
"""
        return pr_body


def count_chars_without_html(string):
    if '<' not in string:
        return len(string)
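The hunk cuts off after the fast path, which returns early when the string contains no markup. One plausible way the slow path could proceed, shown only as an illustration (not necessarily this commit's actual body), is to strip tag-like spans with a regex before counting:

import re

def count_chars_without_html(string):
    if '<' not in string:
        return len(string)
    # drop anything shaped like an HTML tag, then count what remains
    return len(re.sub(r'<[^>]+>', '', string))

print(count_chars_without_html("<strong>app.py</strong>"))  # 6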

View File

@@ -16,8 +12,7 @@ from utils.pr_agent.log import get_logger
class PRGenerateLabels:
    def __init__(
        self,
        pr_url: str,
        args: list = None,
        ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
    ):
        """
        Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
        corresponding to the PR using an AI model.
@@ -93,7 +97,9 @@ class PRGenerateLabels:
            elif pr_labels:
                value = ', '.join(v for v in pr_labels)
                pr_labels_text = f"## PR Labels:\n{value}\n"
                self.git_provider.publish_comment(
                    pr_labels_text, is_temporary=False
                )
            self.git_provider.remove_initial_comment()
        except Exception as e:
            get_logger().error(f"Error generating PR labels {self.pr_id}: {e}")
@@ -137,14 +143,18 @@ class PRGenerateLabels:
        set_custom_labels(variables, self.git_provider)
        self.variables = variables

        system_prompt = environment.from_string(
            get_settings().pr_custom_labels_prompt.system
        ).render(self.variables)
        user_prompt = environment.from_string(
            get_settings().pr_custom_labels_prompt.user
        ).render(self.variables)
        response, finish_reason = await self.ai_handler.chat_completion(
            model=model,
            temperature=get_settings().config.temperature,
            system=system_prompt,
            user=user_prompt,
        )
        return response
@@ -153,8 +163,6 @@ class PRGenerateLabels:
        # Load the AI prediction data into a dictionary
        self.data = load_yaml(self.prediction.strip())

    def _prepare_labels(self) -> List[str]:
        pr_types = []
@@ -174,6 +182,8 @@ class PRGenerateLabels:
                if label_i in d:
                    pr_types[i] = d[label_i]
        except Exception as e:
            get_logger().error(
                f"Error converting labels to original case {self.pr_id}: {e}"
            )
        return pr_types
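The loop above restores user-defined label casing by looking up each lowercased label in a dict keyed by lowercase. The same idea in isolation, as a small sketch:

def restore_original_case(labels: list[str], original_labels: list[str]) -> list[str]:
    # map lowercased label -> label as originally defined
    d = {label.lower(): label for label in original_labels}
    return [d.get(label.lower(), label) for label in labels]

print(restore_original_case(["bug fix"], ["Bug fix", "Enhancement"]))  # ['Bug fix']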

View File

@@ -12,7 +12,11 @@ from utils.pr_agent.algo.pr_processing import retry_with_fallback_models
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.utils import ModelType, clip_tokens, load_yaml, get_max_tokens
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import (
    BitbucketServerProvider,
    GithubProvider,
    get_git_provider_with_context,
)
from utils.pr_agent.log import get_logger
@@ -29,31 +33,50 @@ def extract_header(snippet):
        res = f"#{highest_header.lower().replace(' ', '-')}"
    return res


class PRHelpMessage:
    def __init__(
        self,
        pr_url: str,
        args=None,
        ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
        return_as_string=False,
    ):
        self.git_provider = get_git_provider_with_context(pr_url)
        self.ai_handler = ai_handler()
        self.question_str = self.parse_args(args)
        self.return_as_string = return_as_string
        self.num_retrieved_snippets = get_settings().get(
            'pr_help.num_retrieved_snippets', 5
        )
        if self.question_str:
            self.vars = {
                "question": self.question_str,
                "snippets": "",
            }
            self.token_handler = TokenHandler(
                None,
                self.vars,
                get_settings().pr_help_prompts.system,
                get_settings().pr_help_prompts.user,
            )

    async def _prepare_prediction(self, model: str):
        try:
            variables = copy.deepcopy(self.vars)
            environment = Environment(undefined=StrictUndefined)
            system_prompt = environment.from_string(
                get_settings().pr_help_prompts.system
            ).render(variables)
            user_prompt = environment.from_string(
                get_settings().pr_help_prompts.user
            ).render(variables)
            response, finish_reason = await self.ai_handler.chat_completion(
                model=model,
                temperature=get_settings().config.temperature,
                system=system_prompt,
                user=user_prompt,
            )
            return response
        except Exception as e:
            get_logger().error(f"Error while preparing prediction: {e}")
@@ -81,7 +104,7 @@ class PRHelpMessage:
            '.': '',
            '?': '',
            '!': '',
            ' ': '-',
        }

        # Compile regex pattern for characters to remove
@@ -90,21 +113,27 @@ class PRHelpMessage:
            # Perform replacements in a single pass and convert to lowercase
            return pattern.sub(lambda m: replacements[m.group()], cleaned).lower()
        except Exception:
            get_logger().exception(
                f"Error while formatting markdown header", artifacts={'header': header}
            )
            return ""

    async def run(self):
        try:
            if self.question_str:
                get_logger().info(
                    f'Answering a PR question about the PR {self.git_provider.pr_url} '
                )

                if not get_settings().get('openai.key'):
                    if get_settings().config.publish_output:
                        self.git_provider.publish_comment(
                            "The `Help` tool chat feature requires an OpenAI API key for calculating embeddings"
                        )
                    else:
                        get_logger().error(
                            "The `Help` tool chat feature requires an OpenAI API key for calculating embeddings"
                        )
                    return

                # current path
@@ -112,15 +141,41 @@ class PRHelpMessage:
                # get all the 'md' files inside docs_path and its subdirectories
                md_files = list(docs_path.glob('**/*.md'))
                folders_to_exclude = ['/finetuning_benchmark/']
                files_to_exclude = {
                    'EXAMPLE_BEST_PRACTICE.md',
                    'compression_strategy.md',
                    '/docs/overview/index.md',
                }
                md_files = [
                    file
                    for file in md_files
                    if not any(folder in str(file) for folder in folders_to_exclude)
                    and not any(
                        file.name == file_to_exclude
                        for file_to_exclude in files_to_exclude
                    )
                ]

                # sort the 'md_files' so that 'priority_files' will be at the top
                priority_files_strings = [
                    '/docs/index.md',
                    '/usage-guide',
                    'tools/describe.md',
                    'tools/review.md',
                    'tools/improve.md',
                    '/faq',
                ]
                md_files_priority = [
                    file
                    for file in md_files
                    if any(
                        priority_string in str(file)
                        for priority_string in priority_files_strings
                    )
                ]
                md_files_not_priority = [
                    file for file in md_files if file not in md_files_priority
                ]
                md_files = md_files_priority + md_files_not_priority

                docs_prompt = ""
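The two list comprehensions above implement a stable two-pass ordering: files matching a priority substring come first, everything else keeps its original order behind them. The same idea in isolation:

def order_by_priority(files: list[str], priority_strings: list[str]) -> list[str]:
    # pass 1: priority matches in their original order; pass 2: the rest
    priority = [f for f in files if any(p in f for p in priority_strings)]
    rest = [f for f in files if f not in priority]
    return priority + rest

print(order_by_priority(["guide.md", "docs/index.md"], ["docs/index.md"]))
# ['docs/index.md', 'guide.md']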
@@ -132,24 +187,36 @@ class PRHelpMessage:
                    except Exception as e:
                        get_logger().error(f"Error while reading the file {file}: {e}")
                token_count = self.token_handler.count_tokens(docs_prompt)
                get_logger().debug(
                    f"Token count of full documentation website: {token_count}"
                )

                model = get_settings().config.model
                if model in MAX_TOKENS:
                    max_tokens_full = MAX_TOKENS[
                        model
                    ]  # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
                else:
                    max_tokens_full = get_max_tokens(model)
                delta_output = 2000
                if token_count > max_tokens_full - delta_output:
                    get_logger().info(
                        f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message."
                    )
                    docs_prompt = clip_tokens(
                        docs_prompt, max_tokens_full - delta_output
                    )
                self.vars['snippets'] = docs_prompt.strip()

                # run the AI model
                response = await retry_with_fallback_models(
                    self._prepare_prediction, model_type=ModelType.REGULAR
                )
                response_yaml = load_yaml(response)
                if isinstance(response_yaml, str):
                    get_logger().warning(
                        f"failing to parse response: {response_yaml}, publishing the response as is"
                    )
                    if get_settings().config.publish_output:
                        answer_str = f"### Question: \n{self.question_str}\n\n"
                        answer_str += f"### Answer:\n\n"
@@ -160,7 +227,9 @@ class PRHelpMessage:
                relevant_sections = response_yaml.get('relevant_sections')

                if not relevant_sections:
                    get_logger().info(
                        f"Could not find relevant answer for the question: {self.question_str}"
                    )
                    if get_settings().config.publish_output:
                        answer_str = f"### Question: \n{self.question_str}\n\n"
                        answer_str += f"### Answer:\n\n"
@@ -178,29 +247,38 @@ class PRHelpMessage:
                for section in relevant_sections:
                    file = section.get('file_name').strip().removesuffix('.md')
                    if str(section['relevant_section_header_string']).strip():
                        markdown_header = self.format_markdown_header(
                            section['relevant_section_header_string']
                        )
                        answer_str += f"> - {base_path}{file}#{markdown_header}\n"
                    else:
                        answer_str += f"> - {base_path}{file}\n"

                # publish the answer
                if get_settings().config.publish_output:
                    self.git_provider.publish_comment(answer_str)
                else:
                    get_logger().info(f"Answer:\n{answer_str}")
            else:
                if not isinstance(
                    self.git_provider, BitbucketServerProvider
                ) and not self.git_provider.is_supported("gfm_markdown"):
                    self.git_provider.publish_comment(
                        "The `Help` tool requires gfm markdown, which is not supported by your code platform."
                    )
                    return

                get_logger().info('Getting PR Help Message...')
                relevant_configs = {
                    'pr_help': dict(get_settings().pr_help),
                    'config': dict(get_settings().config),
                }
                get_logger().debug("Relevant configs", artifacts=relevant_configs)
                pr_comment = "## PR Agent Walkthrough 🤖\n\n"
                pr_comment += (
                    "Welcome to the PR Agent, an AI-powered tool for automated pull request analysis, feedback, suggestions and more."
                    ""
                )
                pr_comment += "\n\nHere is a list of tools you can use to interact with the PR Agent:\n"
                base_path = "https://pr-agent-docs.codium.ai/tools"
@@ -211,30 +289,56 @@ class PRHelpMessage:
                tool_names.append(f"[UPDATE CHANGELOG]({base_path}/update_changelog/)")
                tool_names.append(f"[ADD DOCS]({base_path}/documentation/) 💎")
                tool_names.append(f"[TEST]({base_path}/test/) 💎")
                tool_names.append(
                    f"[IMPROVE COMPONENT]({base_path}/improve_component/) 💎"
                )
                tool_names.append(f"[ANALYZE]({base_path}/analyze/) 💎")
                tool_names.append(f"[ASK]({base_path}/ask/)")
                tool_names.append(f"[SIMILAR ISSUE]({base_path}/similar_issues/)")
                tool_names.append(
                    f"[GENERATE CUSTOM LABELS]({base_path}/custom_labels/) 💎"
                )
                tool_names.append(f"[CI FEEDBACK]({base_path}/ci_feedback/) 💎")
                tool_names.append(f"[CUSTOM PROMPT]({base_path}/custom_prompt/) 💎")
                tool_names.append(f"[IMPLEMENT]({base_path}/implement/) 💎")

                descriptions = []
                descriptions.append(
                    "Generates PR description - title, type, summary, code walkthrough and labels"
                )
                descriptions.append(
                    "Adjustable feedback about the PR, possible issues, security concerns, review effort and more"
                )
                descriptions.append("Code suggestions for improving the PR")
                descriptions.append("Automatically updates the changelog")
                descriptions.append(
                    "Generates documentation to methods/functions/classes that changed in the PR"
                )
                descriptions.append(
                    "Generates unit tests for a specific component, based on the PR code change"
                )
                descriptions.append(
                    "Code suggestions for a specific component that changed in the PR"
                )
                descriptions.append(
                    "Identifies code components that changed in the PR, and enables to interactively generate tests, docs, and code suggestions for each component"
                )
                descriptions.append("Answering free-text questions about the PR")
                descriptions.append(
                    "Automatically retrieves and presents similar issues"
                )
                descriptions.append(
                    "Generates custom labels for the PR, based on specific guidelines defined by the user"
                )
                descriptions.append(
                    "Generates feedback and analysis for a failed CI job"
                )
                descriptions.append(
                    "Generates custom suggestions for improving the PR code, derived only from a specific guidelines prompt defined by the user"
                )
                descriptions.append(
                    "Generates implementation code from review suggestions"
                )

                commands = []
                commands.append("`/describe`")
@@ -271,7 +375,9 @@ class PRHelpMessage:
                    checkbox_list.append("[*]")
                    checkbox_list.append("[*]")

                if isinstance(
                    self.git_provider, GithubProvider
                ) and not get_settings().config.get('disable_checkboxes', False):
                    pr_comment += f"<table><tr align='left'><th align='left'>Tool</th><th align='left'>Description</th><th align='left'>Trigger Interactively :gem:</th></tr>"
                    for i in range(len(tool_names)):
                        pr_comment += f"\n<tr><td align='left'>\n\n<strong>{tool_names[i]}</strong></td>\n<td>{descriptions[i]}</td>\n<td>\n\n{checkbox_list[i]}\n</td></tr>"

View File

@@ -5,8 +5,7 @@ from jinja2 import Environment, StrictUndefined
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from utils.pr_agent.algo.git_patch_processing import extract_hunk_lines_from_patch
from utils.pr_agent.algo.pr_processing import retry_with_fallback_models
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.utils import ModelType
@@ -17,7 +16,12 @@ from utils.pr_agent.log import get_logger
class PR_LineQuestions:
    def __init__(
        self,
        pr_url: str,
        args=None,
        ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
    ):
        self.question_str = self.parse_args(args)
        self.git_provider = get_git_provider()(pr_url)
        self.main_pr_language = get_main_pr_language(
@@ -34,10 +38,12 @@ class PR_LineQuestions:
            "full_hunk": "",
            "selected_lines": "",
        }
        self.token_handler = TokenHandler(
            self.git_provider.pr,
            self.vars,
            get_settings().pr_line_questions_prompt.system,
            get_settings().pr_line_questions_prompt.user,
        )
        self.patches_diff = None
        self.prediction = None
@@ -48,7 +54,6 @@ class PR_LineQuestions:
            question_str = ""
        return question_str

    async def run(self):
        get_logger().info('Answering a PR lines question...')
        # if get_settings().config.publish_output:
@@ -62,22 +67,27 @@ class PR_LineQuestions:
        file_name = get_settings().get('file_name', '')
        comment_id = get_settings().get('comment_id', '')
        if ask_diff:
            self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch(
                ask_diff, file_name, line_start=line_start, line_end=line_end, side=side
            )
        else:
            diff_files = self.git_provider.get_diff_files()
            for file in diff_files:
                if file.filename == file_name:
                    (
                        self.patch_with_lines,
                        self.selected_lines,
                    ) = extract_hunk_lines_from_patch(
                        file.patch,
                        file.filename,
                        line_start=line_start,
                        line_end=line_end,
                        side=side,
                    )

        if self.patch_with_lines:
            model_answer = await retry_with_fallback_models(
                self._get_prediction, model_type=ModelType.WEAK
            )
            # sanitize the answer so that no line will start with "/"
            model_answer_sanitized = model_answer.strip().replace("\n/", "\n /")
            if model_answer_sanitized.startswith("/"):
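The sanitizing step above prefixes a space before any line that starts with "/", presumably so the published answer cannot be re-parsed as bot commands. Its effect in isolation:

answer = "/describe\nsome text\n/review"
sanitized = answer.strip().replace("\n/", "\n /")
if sanitized.startswith("/"):
    sanitized = " " + sanitized
print(repr(sanitized))  # ' /describe\nsome text\n /review'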
@@ -85,7 +95,9 @@ class PR_LineQuestions:
            get_logger().info('Preparing answer...')
            if comment_id:
                self.git_provider.reply_to_comment_from_comment_id(
                    comment_id, model_answer_sanitized
                )
            else:
                self.git_provider.publish_comment(model_answer_sanitized)
@@ -96,8 +108,12 @@ class PR_LineQuestions:
        variables["full_hunk"] = self.patch_with_lines  # update diff
        variables["selected_lines"] = self.selected_lines
        environment = Environment(undefined=StrictUndefined)
        system_prompt = environment.from_string(
            get_settings().pr_line_questions_prompt.system
        ).render(variables)
        user_prompt = environment.from_string(
            get_settings().pr_line_questions_prompt.user
        ).render(variables)
        if get_settings().config.verbosity_level >= 2:
            # get_logger().info(f"\nSystem prompt:\n{system_prompt}")
            # get_logger().info(f"\nUser prompt:\n{user_prompt}")
@@ -105,5 +121,9 @@ class PR_LineQuestions:
            print(f"\nUser prompt:\n{user_prompt}")

        response, finish_reason = await self.ai_handler.chat_completion(
            model=model,
            temperature=get_settings().config.temperature,
            system=system_prompt,
            user=user_prompt,
        )
        return response

View File

@@ -16,7 +12,7 @@ from utils.pr_agent.servers.help import HelpMessage
class PRQuestions:
    def __init__(
        self,
        pr_url: str,
        args=None,
        ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
    ):
        question_str = self.parse_args(args)
        self.pr_url = pr_url
        self.git_provider = get_git_provider()(pr_url)
@@ -36,10 +41,12 @@ class PRQuestions:
            "questions": self.question_str,
            "commit_messages_str": self.git_provider.get_commit_messages(),
        }
        self.token_handler = TokenHandler(
            self.git_provider.pr,
            self.vars,
            get_settings().pr_questions_prompt.system,
            get_settings().pr_questions_prompt.user,
        )
        self.patches_diff = None
        self.prediction = None
@@ -52,8 +59,10 @@ class PRQuestions:
    async def run(self):
        get_logger().info(f'Answering a PR question about the PR {self.pr_url} ')
        relevant_configs = {
            'pr_questions': dict(get_settings().pr_questions),
            'config': dict(get_settings().config),
        }
        get_logger().debug("Relevant configs", artifacts=relevant_configs)
        if get_settings().config.publish_output:
            self.git_provider.publish_comment("思考回答中...", is_temporary=True)
@@ -63,12 +72,17 @@ class PRQuestions:
        if img_path:
            get_logger().debug(f"Image path identified", artifact=img_path)
        await retry_with_fallback_models(
            self._prepare_prediction, model_type=ModelType.WEAK
        )
        pr_comment = self._prepare_pr_answer()
        get_logger().debug(f"PR output", artifact=pr_comment)
        if (
            self.git_provider.is_supported("gfm_markdown")
            and get_settings().pr_questions.enable_help_text
        ):
            pr_comment += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
            pr_comment += HelpMessage.get_ask_usage_guide()
            pr_comment += "\n</details>\n"
@@ -85,7 +99,9 @@ class PRQuestions:
            # /ask question ... > ![image](img_path)
            img_path = self.question_str.split('![image]')[1].strip().strip('()')
            self.vars['img_path'] = img_path
        elif 'https://' in self.question_str and (
            '.png' in self.question_str or 'jpg' in self.question_str
        ):  # direct image link
            # include https:// in the image path
            img_path = 'https://' + self.question_str.split('https://')[1]
            self.vars['img_path'] = img_path
@@ -104,16 +120,28 @@ class PRQuestions:
        variables = copy.deepcopy(self.vars)
        variables["diff"] = self.patches_diff  # update diff
        environment = Environment(undefined=StrictUndefined)
        system_prompt = environment.from_string(
            get_settings().pr_questions_prompt.system
        ).render(variables)
        user_prompt = environment.from_string(
            get_settings().pr_questions_prompt.user
        ).render(variables)
        if 'img_path' in variables:
            img_path = self.vars['img_path']
            response, finish_reason = await self.ai_handler.chat_completion(
                model=model,
                temperature=get_settings().config.temperature,
                system=system_prompt,
                user=user_prompt,
                img_path=img_path,
            )
        else:
            response, finish_reason = await self.ai_handler.chat_completion(
                model=model,
                temperature=get_settings().config.temperature,
                system=system_prompt,
                user=user_prompt,
            )
        return response

    def _prepare_pr_answer(self) -> str:
@@ -123,9 +151,13 @@ class PRQuestions:
        if model_answer_sanitized.startswith("/"):
            model_answer_sanitized = " " + model_answer_sanitized
        if model_answer_sanitized != model_answer:
            get_logger().debug(
                f"Sanitized model answer",
                artifact={
                    "model_answer": model_answer,
                    "sanitized_answer": model_answer_sanitized,
                },
            )

        answer_str = f"### **Ask**❓\n{self.question_str}\n\n"
        answer_str += f"### **Answer:**\n{model_answer_sanitized}\n\n"

View File

@@ -7,21 +7,29 @@ from jinja2 import Environment, StrictUndefined
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from utils.pr_agent.algo.pr_processing import (
    add_ai_metadata_to_diff_files,
    get_pr_diff,
    retry_with_fallback_models,
)
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.utils import (
    ModelType,
    PRReviewHeader,
    convert_to_markdown_v2,
    github_action_output,
    load_yaml,
    show_relevant_configurations,
)
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import get_git_provider_with_context
from utils.pr_agent.git_providers.git_provider import (
    IncrementalPR,
    get_main_pr_language,
)
from utils.pr_agent.log import get_logger
from utils.pr_agent.servers.help import HelpMessage
from utils.pr_agent.tools.ticket_pr_compliance_check import extract_and_cache_pr_tickets


class PRReviewer:
@ -29,8 +37,14 @@ class PRReviewer:
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
""" """
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, def __init__(
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): self,
pr_url: str,
is_answer: bool = False,
is_auto: bool = False,
args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
):
""" """
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
@ -55,16 +69,23 @@ class PRReviewer:
self.is_auto = is_auto self.is_auto = is_auto
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"): if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now") raise Exception(
f"Answer mode is not supported for {get_settings().config.git_provider} for now"
)
self.ai_handler = ai_handler() self.ai_handler = ai_handler()
self.ai_handler.main_pr_language = self.main_language self.ai_handler.main_pr_language = self.main_language
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
answer_str, question_str = self._get_user_answers() answer_str, question_str = self._get_user_answers()
self.pr_description, self.pr_description_files = ( (
self.git_provider.get_pr_description(split_changes_walkthrough=True)) self.pr_description,
if (self.pr_description_files and get_settings().get("config.is_auto_command", False) and self.pr_description_files,
get_settings().get("config.enable_ai_metadata", False)): ) = self.git_provider.get_pr_description(split_changes_walkthrough=True)
if (
self.pr_description_files
and get_settings().get("config.is_auto_command", False)
and get_settings().get("config.enable_ai_metadata", False)
):
add_ai_metadata_to_diff_files(self.git_provider, self.pr_description_files) add_ai_metadata_to_diff_files(self.git_provider, self.pr_description_files)
get_logger().debug(f"AI metadata added to the this command") get_logger().debug(f"AI metadata added to the this command")
else: else:
@ -91,7 +112,9 @@ class PRReviewer:
"enable_custom_labels": get_settings().config.enable_custom_labels, "enable_custom_labels": get_settings().config.enable_custom_labels,
"is_ai_metadata": get_settings().get("config.enable_ai_metadata", False), "is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
"related_tickets": get_settings().get('related_tickets', []), "related_tickets": get_settings().get('related_tickets', []),
'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False), 'duplicate_prompt_examples': get_settings().config.get(
'duplicate_prompt_examples', False
),
"date": datetime.datetime.now().strftime('%Y-%m-%d'), "date": datetime.datetime.now().strftime('%Y-%m-%d'),
} }
@ -99,7 +122,7 @@ class PRReviewer:
self.git_provider.pr, self.git_provider.pr,
self.vars, self.vars,
get_settings().pr_review_prompt.system, get_settings().pr_review_prompt.system,
get_settings().pr_review_prompt.user get_settings().pr_review_prompt.user,
) )
def parse_incremental(self, args: List[str]): def parse_incremental(self, args: List[str]):
@ -117,7 +140,10 @@ class PRReviewer:
get_logger().info(f"PR has no files: {self.pr_url}, skipping review") get_logger().info(f"PR has no files: {self.pr_url}, skipping review")
return None return None
if self.incremental.is_incremental and not self._can_run_incremental_review(): if (
self.incremental.is_incremental
and not self._can_run_incremental_review()
):
return None return None
# if isinstance(self.args, list) and self.args and self.args[0] == 'auto_approve': # if isinstance(self.args, list) and self.args and self.args[0] == 'auto_approve':
@ -126,27 +152,41 @@ class PRReviewer:
# return None # return None
get_logger().info(f'Reviewing PR: {self.pr_url} ...') get_logger().info(f'Reviewing PR: {self.pr_url} ...')
relevant_configs = {'pr_reviewer': dict(get_settings().pr_reviewer), relevant_configs = {
'config': dict(get_settings().config)} 'pr_reviewer': dict(get_settings().pr_reviewer),
'config': dict(get_settings().config),
}
get_logger().debug("Relevant configs", artifacts=relevant_configs) get_logger().debug("Relevant configs", artifacts=relevant_configs)
# ticket extraction if exists # ticket extraction if exists
await extract_and_cache_pr_tickets(self.git_provider, self.vars) await extract_and_cache_pr_tickets(self.git_provider, self.vars)
if self.incremental.is_incremental and hasattr(self.git_provider, "unreviewed_files_set") and not self.git_provider.unreviewed_files_set: if (
get_logger().info(f"Incremental review is enabled for {self.pr_url} but there are no new files") self.incremental.is_incremental
and hasattr(self.git_provider, "unreviewed_files_set")
and not self.git_provider.unreviewed_files_set
):
get_logger().info(
f"Incremental review is enabled for {self.pr_url} but there are no new files"
)
previous_review_url = "" previous_review_url = ""
if hasattr(self.git_provider, "previous_review"): if hasattr(self.git_provider, "previous_review"):
previous_review_url = self.git_provider.previous_review.html_url previous_review_url = self.git_provider.previous_review.html_url
if get_settings().config.publish_output: if get_settings().config.publish_output:
self.git_provider.publish_comment(f"Incremental Review Skipped\n" self.git_provider.publish_comment(
f"No files were changed since the [previous PR Review]({previous_review_url})") f"Incremental Review Skipped\n"
f"No files were changed since the [previous PR Review]({previous_review_url})"
)
return None return None
if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False): if get_settings().config.publish_output and not get_settings().config.get(
'is_auto_command', False
):
self.git_provider.publish_comment("准备评审中...", is_temporary=True) self.git_provider.publish_comment("准备评审中...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR) await retry_with_fallback_models(
self._prepare_prediction, model_type=ModelType.REGULAR
)
if not self.prediction: if not self.prediction:
self.git_provider.remove_initial_comment() self.git_provider.remove_initial_comment()
return None return None
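retry_with_fallback_models is imported from utils.pr_agent.algo.pr_processing and its body is not part of this diff, so the sketch below is only a guess at the general shape of such a helper; the name, signature, and behavior are all assumptions, not the project's implementation:

from typing import Awaitable, Callable, Sequence


async def retry_with_fallbacks(
    prepare: Callable[[str], Awaitable[None]],
    models: Sequence[str],
) -> None:
    """Try each model in order; re-raise the last error if all fail."""
    last_error: Exception | None = None
    for model in models:  # strongest model first, weaker fallbacks after
        try:
            await prepare(model)
            return
        except Exception as e:  # fall through and try the next model
            last_error = e
    if last_error is not None:
        raise last_error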
@@ -156,12 +196,19 @@ class PRReviewer:
            if get_settings().config.publish_output:
                # publish the review
                if (
                    get_settings().pr_reviewer.persistent_comment
                    and not self.incremental.is_incremental
                ):
                    final_update_message = (
                        get_settings().pr_reviewer.final_update_message
                    )
                    self.git_provider.publish_persistent_comment(
                        pr_review,
                        initial_header=f"{PRReviewHeader.REGULAR.value} 🔍",
                        update_header=True,
                        final_update_message=final_update_message,
                    )
                else:
                    self.git_provider.publish_comment(pr_review)

@@ -174,11 +221,13 @@ class PRReviewer:
            get_logger().error(f"Failed to review PR: {e}")

    async def _prepare_prediction(self, model: str) -> None:
        self.patches_diff = get_pr_diff(
            self.git_provider,
            self.token_handler,
            model,
            add_line_numbers_to_hunks=True,
            disable_extra_lines=False,
        )

        if self.patches_diff:
            get_logger().debug(f"PR diff", diff=self.patches_diff)

@@ -201,14 +250,18 @@ class PRReviewer:
        variables["diff"] = self.patches_diff  # update diff
        environment = Environment(undefined=StrictUndefined)
        system_prompt = environment.from_string(
            get_settings().pr_review_prompt.system
        ).render(variables)
        user_prompt = environment.from_string(
            get_settings().pr_review_prompt.user
        ).render(variables)

        response, finish_reason = await self.ai_handler.chat_completion(
            model=model,
            temperature=get_settings().config.temperature,
            system=system_prompt,
            user=user_prompt,
        )

        return response

@@ -220,10 +273,20 @@ class PRReviewer:
        """
        first_key = 'review'
        last_key = 'security_concerns'
        data = load_yaml(
            self.prediction.strip(),
            keys_fix_yaml=[
                "ticket_compliance_check",
                "estimated_effort_to_review_[1-5]:",
                "security_concerns:",
                "key_issues_to_review:",
                "relevant_file:",
                "relevant_line:",
                "suggestion:",
            ],
            first_key=first_key,
            last_key=last_key,
        )
        github_action_output(data, 'review')

        # move data['review'] 'key_issues_to_review' key to the end of the dictionary

@@ -234,24 +297,38 @@ class PRReviewer:
        incremental_review_markdown_text = None
        # Add incremental review section
        if self.incremental.is_incremental:
            last_commit_url = (
                f"{self.git_provider.get_pr_url()}/commits/"
                f"{self.git_provider.incremental.first_new_commit_sha}"
            )
            incremental_review_markdown_text = f"Starting from commit {last_commit_url}"

        markdown_text = convert_to_markdown_v2(
            data,
            self.git_provider.is_supported("gfm_markdown"),
            incremental_review_markdown_text,
            git_provider=self.git_provider,
            files=self.git_provider.get_diff_files(),
        )

        # Add help text if gfm_markdown is supported
        if (
            self.git_provider.is_supported("gfm_markdown")
            and get_settings().pr_reviewer.enable_help_text
        ):
            markdown_text += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
            markdown_text += HelpMessage.get_review_usage_guide()
            markdown_text += "\n</details>\n"

        # Output the relevant configurations if enabled
        if (
            get_settings()
            .get('config', {})
            .get('output_relevant_configurations', False)
        ):
            markdown_text += show_relevant_configurations(
                relevant_section='pr_reviewer'
            )

        # Add custom labels from the review prediction (effort, security)
        self.set_review_labels(data)
@@ -306,34 +383,50 @@ class PRReviewer:
                    if comment:
                        self.git_provider.remove_comment(comment)
        except Exception as e:
            get_logger().exception(
                f"Failed to remove previous review comment, error: {e}"
            )

    def _can_run_incremental_review(self) -> bool:
        """Checks if we can run incremental review according the various configurations and previous review"""
        # checking if running is auto mode but there are no new commits
        if self.is_auto and not self.incremental.first_new_commit_sha:
            get_logger().info(
                f"Incremental review is enabled for {self.pr_url} but there are no new commits"
            )
            return False

        if not hasattr(self.git_provider, "get_incremental_commits"):
            get_logger().info(
                f"Incremental review is not supported for {get_settings().config.git_provider}"
            )
            return False
        # checking if there are enough commits to start the review
        num_new_commits = len(self.incremental.commits_range)
        num_commits_threshold = (
            get_settings().pr_reviewer.minimal_commits_for_incremental_review
        )
        not_enough_commits = num_new_commits < num_commits_threshold
        # checking if the commits are not too recent to start the review
        recent_commits_threshold = datetime.datetime.now() - datetime.timedelta(
            minutes=get_settings().pr_reviewer.minimal_minutes_for_incremental_review
        )
        last_seen_commit_date = (
            self.incremental.last_seen_commit.commit.author.date
            if self.incremental.last_seen_commit
            else None
        )
        all_commits_too_recent = (
            last_seen_commit_date > recent_commits_threshold
            if self.incremental.last_seen_commit
            else False
        )
        # check all the thresholds or just one to start the review
        condition = (
            any
            if get_settings().pr_reviewer.require_all_thresholds_for_incremental_review
            else all
        )
        if condition((not_enough_commits, all_commits_too_recent)):
            get_logger().info(
                f"Incremental review is enabled for {self.pr_url} but didn't pass the threshold check to run:"
@@ -348,31 +441,55 @@ class PRReviewer:
            return

        if not get_settings().pr_reviewer.require_estimate_effort_to_review:
            get_settings().pr_reviewer.enable_review_labels_effort = (
                False  # we did not generate this output
            )
        if not get_settings().pr_reviewer.require_security_review:
            get_settings().pr_reviewer.enable_review_labels_security = (
                False  # we did not generate this output
            )

        if (
            get_settings().pr_reviewer.enable_review_labels_security
            or get_settings().pr_reviewer.enable_review_labels_effort
        ):
            try:
                review_labels = []
                if get_settings().pr_reviewer.enable_review_labels_effort:
                    estimated_effort = data['review'][
                        'estimated_effort_to_review_[1-5]'
                    ]
                    estimated_effort_number = 0
                    if isinstance(estimated_effort, str):
                        try:
                            estimated_effort_number = int(
                                estimated_effort.split(',')[0]
                            )
                        except ValueError:
                            get_logger().warning(
                                f"Invalid estimated_effort value: {estimated_effort}"
                            )
                    elif isinstance(estimated_effort, int):
                        estimated_effort_number = estimated_effort
                    else:
                        get_logger().warning(
                            f"Unexpected type for estimated_effort: {type(estimated_effort)}"
                        )
                    if 1 <= estimated_effort_number <= 5:  # 1, because ...
                        review_labels.append(
                            f'Review effort {estimated_effort_number}/5'
                        )
                if (
                    get_settings().pr_reviewer.enable_review_labels_security
                    and get_settings().pr_reviewer.require_security_review
                ):
                    security_concerns = data['review'][
                        'security_concerns'
                    ]  # yes, because ...
                    security_concerns_bool = (
                        'yes' in security_concerns.lower()
                        or 'true' in security_concerns.lower()
                    )
                    if security_concerns_bool:
                        review_labels.append('Possible security concern')

@@ -381,17 +498,26 @@ class PRReviewer:
                current_labels = []
                get_logger().debug(f"Current labels:\n{current_labels}")
                if current_labels:
                    current_labels_filtered = [
                        label
                        for label in current_labels
                        if not label.lower().startswith('review effort')
                        and not label.lower().startswith('possible security concern')
                    ]
                else:
                    current_labels_filtered = []
                new_labels = review_labels + current_labels_filtered
                if (current_labels or review_labels) and sorted(new_labels) != sorted(
                    current_labels
                ):
                    get_logger().info(
                        f"Setting review labels:\n{review_labels + current_labels_filtered}"
                    )
                    self.git_provider.publish_labels(new_labels)
                else:
                    get_logger().info(
                        f"Review labels are already set:\n{review_labels + current_labels_filtered}"
                    )
            except Exception as e:
                get_logger().error(f"Failed to set review labels, error: {e}")
@@ -406,5 +532,7 @@ class PRReviewer:
                self.git_provider.publish_comment("自动批准 PR")
        else:
            get_logger().info("Auto-approval option is disabled")
            self.git_provider.publish_comment(
                "PR-Agent 的自动批准选项已禁用. "
                "你可以通过此设置打开 [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)"
            )
View File
@@ -24,12 +24,16 @@ class PRSimilarIssue:
        self.max_issues_to_scan = get_settings().pr_similar_issue.max_issues_to_scan
        self.issue_url = issue_url
        self.git_provider = get_git_provider()()
        repo_name, issue_number = self.git_provider._parse_issue_url(
            issue_url.split('=')[-1]
        )
        self.git_provider.repo = repo_name
        self.git_provider.repo_obj = self.git_provider.github_client.get_repo(repo_name)
        self.token_handler = TokenHandler()
        repo_obj = self.git_provider.repo_obj
        repo_name_for_index = self.repo_name_for_index = (
            repo_obj.full_name.lower().replace('/', '-').replace('_/', '-')
        )
        index_name = self.index_name = "codium-ai-pr-agent-issues"

        if get_settings().pr_similar_issue.vectordb == "pinecone":
@@ -38,17 +42,30 @@ class PRSimilarIssue:
                import pinecone
                from pinecone_datasets import Dataset, DatasetMetadata
            except:
                raise Exception(
                    "Please install 'pinecone' and 'pinecone_datasets' to use pinecone as vectordb"
                )
            # assuming pinecone api key and environment are set in secrets file
            try:
                api_key = get_settings().pinecone.api_key
                environment = get_settings().pinecone.environment
            except Exception:
                if not self.cli_mode:
                    (
                        repo_name,
                        original_issue_number,
                    ) = self.git_provider._parse_issue_url(
                        self.issue_url.split('=')[-1]
                    )
                    issue_main = self.git_provider.repo_obj.get_issue(
                        original_issue_number
                    )
                    issue_main.create_comment(
                        "Please set pinecone api key and environment in secrets file"
                    )
                raise Exception(
                    "Please set pinecone api key and environment in secrets file"
                )

            # check if index exists, and if repo is already indexed
            run_from_scratch = False
@@ -69,7 +86,9 @@ class PRSimilarIssue:
                    upsert = True
                else:
                    pinecone_index = pinecone.Index(index_name=index_name)
                    res = pinecone_index.fetch(
                        [f"example_issue_{repo_name_for_index}"]
                    ).to_dict()
                    if res["vectors"]:
                        upsert = False

@@ -79,7 +98,9 @@ class PRSimilarIssue:
                get_logger().info('Getting issues...')
                issues = list(repo_obj.get_issues(state='all'))
                get_logger().info('Done')
                self._update_index_with_issues(
                    issues, repo_name_for_index, upsert=upsert
                )
            else:  # update index if needed
                pinecone_index = pinecone.Index(index_name=index_name)
                issues_to_update = []

@@ -105,7 +126,9 @@ class PRSimilarIssue:
                if issues_to_update:
                    get_logger().info(f'Updating index with {counter} new issues...')
                    self._update_index_with_issues(
                        issues_to_update, repo_name_for_index, upsert=True
                    )
                else:
                    get_logger().info('No new issues to update')

@@ -133,7 +156,12 @@ class PRSimilarIssue:
                    ingest = True
                else:
                    self.table = self.db[index_name]
                    res = (
                        self.table.search()
                        .limit(len(self.table))
                        .where(f"id='example_issue_{repo_name_for_index}'")
                        .to_list()
                    )
                    get_logger().info("result: ", res)
                    if res[0].get("vector"):
                        ingest = False

@@ -145,7 +173,9 @@ class PRSimilarIssue:
                issues = list(repo_obj.get_issues(state='all'))
                get_logger().info('Done')
                self._update_table_with_issues(
                    issues, repo_name_for_index, ingest=ingest
                )
            else:  # update table if needed
                issues_to_update = []
                issues_paginated_list = repo_obj.get_issues(state='all')

@@ -156,7 +186,12 @@ class PRSimilarIssue:
                    issue_str, comments, number = self._process_issue(issue)
                    issue_key = f"issue_{number}"
                    issue_id = issue_key + "." + "issue"
                    res = (
                        self.table.search()
                        .limit(len(self.table))
                        .where(f"id='{issue_id}'")
                        .to_list()
                    )
                    is_new_issue = True
                    for r in res:
                        if r['metadata']['repo'] == repo_name_for_index:

@@ -170,14 +205,17 @@ class PRSimilarIssue:
                if issues_to_update:
                    get_logger().info(f'Updating index with {counter} new issues...')
                    self._update_table_with_issues(
                        issues_to_update, repo_name_for_index, ingest=True
                    )
                else:
                    get_logger().info('No new issues to update')

    async def run(self):
        get_logger().info('Getting issue...')
        repo_name, original_issue_number = self.git_provider._parse_issue_url(
            self.issue_url.split('=')[-1]
        )
        issue_main = self.git_provider.repo_obj.get_issue(original_issue_number)
        issue_str, comments, number = self._process_issue(issue_main)
        openai.api_key = get_settings().openai.key
@@ -193,10 +231,12 @@ class PRSimilarIssue:
        if get_settings().pr_similar_issue.vectordb == "pinecone":
            pinecone_index = pinecone.Index(index_name=self.index_name)
            res = pinecone_index.query(
                embeds[0],
                top_k=5,
                filter={"repo": self.repo_name_for_index},
                include_metadata=True,
            ).to_dict()

            for r in res['matches']:
                # skip example issue

@@ -214,14 +254,20 @@ class PRSimilarIssue:
                if issue_number not in relevant_issues_number_list:
                    relevant_issues_number_list.append(issue_number)
                if 'comment' in r["id"]:
                    relevant_comment_number_list.append(
                        int(r["id"].split('.')[1].split('_')[-1])
                    )
                else:
                    relevant_comment_number_list.append(-1)
                score_list.append(str("{:.2f}".format(r['score'])))
            get_logger().info('Done')

        elif get_settings().pr_similar_issue.vectordb == "lancedb":
            res = (
                self.table.search(embeds[0])
                .where(f"metadata.repo='{self.repo_name_for_index}'", prefilter=True)
                .to_list()
            )

            for r in res:
                # skip example issue

@@ -240,7 +286,9 @@ class PRSimilarIssue:
                    relevant_issues_number_list.append(issue_number)
                if 'comment' in r["id"]:
                    relevant_comment_number_list.append(
                        int(r["id"].split('.')[1].split('_')[-1])
                    )
                else:
                    relevant_comment_number_list.append(-1)
                score_list.append(str("{:.2f}".format(1 - r['_distance'])))
@@ -254,8 +302,12 @@ class PRSimilarIssue:
            title = issue.title
            url = issue.html_url
            if relevant_comment_number_list[i] != -1:
                url = list(issue.get_comments())[
                    relevant_comment_number_list[i]
                ].html_url
            similar_issues_str += (
                f"{i + 1}. **[{title}]({url})** (score={score_list[i]})\n\n"
            )
        if get_settings().config.publish_output:
            response = issue_main.create_comment(similar_issues_str)
        get_logger().info(similar_issues_str)
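Note how the two backends are normalized to one display format in the hunks above: Pinecone matches carry a similarity `score`, while LanceDB rows carry a `_distance` that the code converts with `1 - _distance`. With illustrative values:

pinecone_hit = {"score": 0.87}
lancedb_hit = {"_distance": 0.13}

# both render as the same "score=0.87" text in the comment
assert "{:.2f}".format(pinecone_hit["score"]) == "0.87"
assert "{:.2f}".format(1 - lancedb_hit["_distance"]) == "0.87"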
@@ -278,7 +330,7 @@ class PRSimilarIssue:
        example_issue_record = Record(
            id=f"example_issue_{repo_name_for_index}",
            text="example_issue",
            metadata=Metadata(repo=repo_name_for_index),
        )
        corpus.append(example_issue_record)

@@ -298,15 +350,20 @@ class PRSimilarIssue:
            issue_key = f"issue_{number}"
            username = issue.user.login
            created_at = str(issue.created_at)
            if len(issue_str) < 8000 or self.token_handler.count_tokens(
                issue_str
            ) < get_max_tokens(
                MODEL
            ):  # fast reject first
                issue_record = Record(
                    id=issue_key + "." + "issue",
                    text=issue_str,
                    metadata=Metadata(
                        repo=repo_name_for_index,
                        username=username,
                        created_at=created_at,
                        level=IssueLevel.ISSUE,
                    ),
                )
                corpus.append(issue_record)
                if comments:

@@ -316,15 +373,20 @@ class PRSimilarIssue:
                        if num_words_comment < 10 or not isinstance(comment_body, str):
                            continue

                        if (
                            len(comment_body) < 8000
                            or self.token_handler.count_tokens(comment_body)
                            < MAX_TOKENS[MODEL]
                        ):
                            comment_record = Record(
                                id=issue_key + ".comment_" + str(j + 1),
                                text=comment_body,
                                metadata=Metadata(
                                    repo=repo_name_for_index,
                                    username=username,  # use issue username for all comments
                                    created_at=created_at,
                                    level=IssueLevel.COMMENT,
                                ),
                            )
                            corpus.append(comment_record)
        df = pd.DataFrame(corpus.dict()["documents"])

@@ -355,7 +417,9 @@ class PRSimilarIssue:
        environment = get_settings().pinecone.environment
        if not upsert:
            get_logger().info('Creating index from scratch...')
            ds.to_pinecone_index(
                self.index_name, api_key=api_key, environment=environment
            )
            time.sleep(15)  # wait for pinecone to finalize indexing before querying
        else:
            get_logger().info('Upserting index...')

@@ -374,7 +438,7 @@ class PRSimilarIssue:
        example_issue_record = Record(
            id=f"example_issue_{repo_name_for_index}",
            text="example_issue",
            metadata=Metadata(repo=repo_name_for_index),
        )
        corpus.append(example_issue_record)

@@ -394,15 +458,20 @@ class PRSimilarIssue:
            issue_key = f"issue_{number}"
            username = issue.user.login
            created_at = str(issue.created_at)
            if len(issue_str) < 8000 or self.token_handler.count_tokens(
                issue_str
            ) < get_max_tokens(
                MODEL
            ):  # fast reject first
                issue_record = Record(
                    id=issue_key + "." + "issue",
                    text=issue_str,
                    metadata=Metadata(
                        repo=repo_name_for_index,
                        username=username,
                        created_at=created_at,
                        level=IssueLevel.ISSUE,
                    ),
                )
                corpus.append(issue_record)
                if comments:

@@ -412,15 +481,20 @@ class PRSimilarIssue:
                        if num_words_comment < 10 or not isinstance(comment_body, str):
                            continue

                        if (
                            len(comment_body) < 8000
                            or self.token_handler.count_tokens(comment_body)
                            < MAX_TOKENS[MODEL]
                        ):
                            comment_record = Record(
                                id=issue_key + ".comment_" + str(j + 1),
                                text=comment_body,
                                metadata=Metadata(
                                    repo=repo_name_for_index,
                                    username=username,  # use issue username for all comments
                                    created_at=created_at,
                                    level=IssueLevel.COMMENT,
                                ),
                            )
                            corpus.append(comment_record)
        df = pd.DataFrame(corpus.dict()["documents"])
@@ -446,7 +520,9 @@ class PRSimilarIssue:
        if not ingest:
            get_logger().info('Creating table from scratch...')
            self.table = self.db.create_table(
                self.index_name, data=df, mode="overwrite"
            )
            time.sleep(15)
        else:
            get_logger().info('Ingesting in Table...')
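The repeated "fast reject first" guards in this file short-circuit on a cheap len() check before paying for tokenization. A standalone sketch of the pattern; fits_in_context, the callable, and the 8000-character cutoff below merely stand in for the project's TokenHandler and MAX_TOKENS machinery:

def fits_in_context(text: str, count_tokens, max_tokens: int) -> bool:
    # len() is O(1), so short texts skip the tokenizer call entirely
    return len(text) < 8000 or count_tokens(text) < max_tokens


# naive whitespace "tokenizer" as a stand-in for the real token counter
assert fits_in_context("short issue body", lambda t: len(t.split()), 4096)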
View File
@@ -20,13 +20,20 @@ CHANGELOG_LINES = 50


class PRUpdateChangelog:
    def __init__(
        self,
        pr_url: str,
        cli_mode=False,
        args=None,
        ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
    ):
        self.git_provider = get_git_provider()(pr_url)
        self.main_language = get_main_pr_language(
            self.git_provider.get_languages(), self.git_provider.get_files()
        )
        self.commit_changelog = (
            get_settings().pr_update_changelog.push_changelog_changes
        )
        self._get_changelog_file()  # self.changelog_file_str

        self.ai_handler = ai_handler()
@@ -47,15 +54,19 @@ class PRUpdateChangelog:
            "extra_instructions": get_settings().pr_update_changelog.extra_instructions,
            "commit_messages_str": self.git_provider.get_commit_messages(),
        }
        self.token_handler = TokenHandler(
            self.git_provider.pr,
            self.vars,
            get_settings().pr_update_changelog_prompt.system,
            get_settings().pr_update_changelog_prompt.user,
        )

    async def run(self):
        get_logger().info('Updating the changelog...')
        relevant_configs = {
            'pr_update_changelog': dict(get_settings().pr_update_changelog),
            'config': dict(get_settings().config),
        }
        get_logger().debug("Relevant configs", artifacts=relevant_configs)

        # currently only GitHub is supported for pushing changelog changes

@@ -74,13 +85,21 @@ class PRUpdateChangelog:
        if get_settings().config.publish_output:
            self.git_provider.publish_comment("准备变更日志更新中...", is_temporary=True)

        await retry_with_fallback_models(
            self._prepare_prediction, model_type=ModelType.WEAK
        )

        new_file_content, answer = self._prepare_changelog_update()

        # Output the relevant configurations if enabled
        if (
            get_settings()
            .get('config', {})
            .get('output_relevant_configurations', False)
        ):
            answer += show_relevant_configurations(
                relevant_section='pr_update_changelog'
            )

        get_logger().debug(f"PR output", artifact=answer)
@@ -89,7 +108,9 @@ class PRUpdateChangelog:
        if self.commit_changelog:
            self._push_changelog_update(new_file_content, answer)
        else:
            self.git_provider.publish_comment(
                f"**Changelog updates:** 🔄\n\n{answer}"
            )

    async def _prepare_prediction(self, model: str):
        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)

@@ -106,10 +127,18 @@ class PRUpdateChangelog:
        if get_settings().pr_update_changelog.add_pr_link:
            variables["pr_link"] = self.git_provider.get_pr_url()
        environment = Environment(undefined=StrictUndefined)
        system_prompt = environment.from_string(
            get_settings().pr_update_changelog_prompt.system
        ).render(variables)
        user_prompt = environment.from_string(
            get_settings().pr_update_changelog_prompt.user
        ).render(variables)
        response, finish_reason = await self.ai_handler.chat_completion(
            model=model,
            system=system_prompt,
            user=user_prompt,
            temperature=get_settings().config.temperature,
        )

        # post-process the response
        response = response.strip()

@@ -134,8 +163,10 @@ class PRUpdateChangelog:
            new_file_content = answer

        if not self.commit_changelog:
            answer += (
                "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:"
                "\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n"
            )

        return new_file_content, answer
@@ -163,8 +194,7 @@ class PRUpdateChangelog:
        self.git_provider.publish_comment(f"**Changelog updates: 🔄**\n\n{answer}")

    def _get_default_changelog(self):
        example_changelog = """
Example:
## <current_date>
View File
@@ -10,11 +10,12 @@ GITHUB_TICKET_PATTERN = re.compile(
    r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)'
)


def find_jira_tickets(text):
    # Regular expression patterns for JIRA tickets
    patterns = [
        r'\b[A-Z]{2,10}-\d{1,7}\b',  # Standard JIRA ticket format (e.g., PROJ-123)
        r'(?:https?://[^\s/]+/browse/)?([A-Z]{2,10}-\d{1,7})\b',  # JIRA URL or just the ticket
    ]

    tickets = set()
@@ -32,7 +33,9 @@ def find_jira_tickets(text):
    return list(tickets)


def extract_ticket_links_from_pr_description(
    pr_description, repo_path, base_url_html='https://github.com'
):
    """
    Extract all ticket links from PR description
    """

@@ -46,19 +49,27 @@ def extract_ticket_links_from_pr_description(pr_description, repo_path, base_url
                github_tickets.add(match[0])
            elif match[1]:  # Shorthand notation match: owner/repo#issue_number
                owner, repo, issue_number = match[2], match[3], match[4]
                github_tickets.add(
                    f'{base_url_html.strip("/")}/{owner}/{repo}/issues/{issue_number}'
                )
            else:  # #123 format
                issue_number = match[5][1:]  # remove #
                if issue_number.isdigit() and len(issue_number) < 5 and repo_path:
                    github_tickets.add(
                        f'{base_url_html.strip("/")}/{repo_path}/issues/{issue_number}'
                    )

        if len(github_tickets) > 3:
            get_logger().info(
                f"Too many tickets found in PR description: {len(github_tickets)}"
            )
            # Limit the number of tickets to 3
            github_tickets = set(list(github_tickets)[:3])
    except Exception as e:
        get_logger().error(
            f"Error extracting tickets error= {e}",
            artifact={"traceback": traceback.format_exc()},
        )

    return list(github_tickets)
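GITHUB_TICKET_PATTERN (defined at the top of this file) drives the match[0]/match[1]/match[5] branching above; a quick demonstration with an invented PR description:

import re

GITHUB_TICKET_PATTERN = re.compile(
    r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)'
)

description = "Fixes #42, see octo/widgets#7 and https://github.com/octo/widgets/issues/9"
for match in GITHUB_TICKET_PATTERN.findall(description):
    print(match)
# Each 6-tuple populates exactly one alternative: a full issue URL in
# match[0], an owner/repo#n shorthand in match[1] (with its captured parts
# in match[2..4]), or a bare #n reference in match[5].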
@@ -68,19 +79,26 @@ async def extract_tickets(git_provider):
    try:
        if isinstance(git_provider, GithubProvider):
            user_description = git_provider.get_user_description()
            tickets = extract_ticket_links_from_pr_description(
                user_description, git_provider.repo, git_provider.base_url_html
            )
            tickets_content = []
            if tickets:
                for ticket in tickets:
                    repo_name, original_issue_number = git_provider._parse_issue_url(
                        ticket
                    )
                    try:
                        issue_main = git_provider.repo_obj.get_issue(
                            original_issue_number
                        )
                    except Exception as e:
                        get_logger().error(
                            f"Error getting main issue: {e}",
                            artifact={"traceback": traceback.format_exc()},
                        )
                        continue

                    issue_body_str = issue_main.body or ""

@@ -93,47 +111,66 @@ async def extract_tickets(git_provider):
                        sub_issues = git_provider.fetch_sub_issues(ticket)
                        for sub_issue_url in sub_issues:
                            try:
                                (
                                    sub_repo,
                                    sub_issue_number,
                                ) = git_provider._parse_issue_url(sub_issue_url)
                                sub_issue = git_provider.repo_obj.get_issue(
                                    sub_issue_number
                                )

                                sub_body = sub_issue.body or ""
                                if len(sub_body) > MAX_TICKET_CHARACTERS:
                                    sub_body = sub_body[:MAX_TICKET_CHARACTERS] + "..."

                                sub_issues_content.append(
                                    {
                                        'ticket_url': sub_issue_url,
                                        'title': sub_issue.title,
                                        'body': sub_body,
                                    }
                                )
                            except Exception as e:
                                get_logger().warning(
                                    f"Failed to fetch sub-issue content for {sub_issue_url}: {e}"
                                )
                    except Exception as e:
                        get_logger().warning(
                            f"Failed to fetch sub-issues for {ticket}: {e}"
                        )

                    # Extract labels
                    labels = []
                    try:
                        for label in issue_main.labels:
                            labels.append(
                                label.name if hasattr(label, 'name') else label
                            )
                    except Exception as e:
                        get_logger().error(
                            f"Error extracting labels error= {e}",
                            artifact={"traceback": traceback.format_exc()},
                        )

                    tickets_content.append(
                        {
                            'ticket_id': issue_main.number,
                            'ticket_url': ticket,
                            'title': issue_main.title,
                            'body': issue_body_str,
                            'labels': ", ".join(labels),
                            'sub_issues': sub_issues_content,  # Store sub-issues content
                        }
                    )

            return tickets_content

    except Exception as e:
        get_logger().error(
            f"Error extracting tickets error= {e}",
            artifact={"traceback": traceback.format_exc()},
        )


async def extract_and_cache_pr_tickets(git_provider, vars):
@@ -154,8 +191,10 @@ async def extract_and_cache_pr_tickets(git_provider, vars):
            related_tickets.append(ticket)

        get_logger().info(
            "Extracted tickets and sub-issues from PR description",
            artifact={"tickets": related_tickets},
        )
        vars['related_tickets'] = related_tickets
        get_settings().set('related_tickets', related_tickets)
13
config.ini Normal file
View File

@@ -0,0 +1,13 @@
[BASE]
; Enable debug mode: 0 or 1
DEBUG = 0

[DATABASE]
; Defaults to sqlite; switch to pg in production
DEFAULT = pg
; postgres settings
DB_NAME = pr_manager
DB_USER = admin
DB_PASSWORD = admin123456
DB_HOST = 110.40.30.95
DB_PORT = 5432
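The settings module in the next file mentions that development setups may use a config.local.ini (and .gitignore now excludes that file), yet the code shown reads only config.ini. One way such an override could work, assuming configparser's list form where later files win, is sketched below; this is not the committed implementation:

import configparser
from pathlib import Path

BASE_DIR = Path(__file__).resolve().parent

_config = configparser.ConfigParser()
# config.local.ini is read last, so any key it defines overrides config.ini;
# configparser silently skips files that do not exist.
_config.read(
    [BASE_DIR / "config.ini", BASE_DIR / "config.local.ini"],
    encoding="utf-8",
)
print(_config["DATABASE"].get("DEFAULT", "sqlite"))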
View File
@@ -12,11 +12,18 @@ https://docs.djangoproject.com/en/5.1/ref/settings/
import os
import sys
import configparser

from pathlib import Path

# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent

CONFIG_NAME = BASE_DIR / "config.ini"
# Load the config file; development setups may use a config.local.ini instead
_config = configparser.ConfigParser()
_config.read(CONFIG_NAME, encoding="utf-8")

sys.path.insert(0, os.path.join(BASE_DIR, "apps"))
sys.path.insert(1, os.path.join(BASE_DIR, "apps/utils"))
@@ -27,7 +34,7 @@ sys.path.insert(1, os.path.join(BASE_DIR, "apps/utils"))
SECRET_KEY = "django-insecure-$r6lfcq8rev&&=chw259o$0o7t-!!%clc2ahs3xg$^z+gkms76"

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = bool(int(_config["BASE"].get("DEBUG", "1")))

ALLOWED_HOSTS = ["*"]
@@ -44,7 +51,7 @@ INSTALLED_APPS = [
    "django.contrib.messages",
    "django.contrib.staticfiles",
    "public",
    "pr",
]

# Security key configuration
@@ -68,8 +75,7 @@ ROOT_URLCONF = "pr_manager.urls"
TEMPLATES = [
    {
        "BACKEND": "django.template.backends.django.DjangoTemplates",
        "DIRS": [BASE_DIR / 'templates'],
        "APP_DIRS": True,
        "OPTIONS": {
            "context_processors": [
@@ -89,11 +95,21 @@ WSGI_APPLICATION = "pr_manager.wsgi.application"
# https://docs.djangoproject.com/en/5.1/ref/settings/#databases

DATABASES = {
    "pg": {
        "ENGINE": "django.db.backends.postgresql",
        "NAME": _config["DATABASE"].get("DB_NAME", "chat_ai_v2"),
        "USER": _config["DATABASE"].get("DB_USER", "admin"),
        "PASSWORD": _config["DATABASE"].get("DB_PASSWORD", "admin123456"),
        "HOST": _config["DATABASE"].get("DB_HOST", "124.222.222.101"),
        "PORT": int(_config["DATABASE"].get("DB_PORT", "5432")),
    },
    "sqlite": {
        "ENGINE": "django.db.backends.sqlite3",
        "NAME": BASE_DIR / "db.sqlite3",
    },
}
DATABASES["default"] = DATABASES[_config["DATABASE"].get("DEFAULT", "sqlite")]


# Password validation
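The DATABASES rewiring above, isolated: every backend lives under its own alias, and "default" is just a pointer to whichever alias the config names. The values below are placeholders, not the project's settings:

DATABASES = {
    "pg": {"ENGINE": "django.db.backends.postgresql", "NAME": "pr_manager"},
    "sqlite": {"ENGINE": "django.db.backends.sqlite3", "NAME": "db.sqlite3"},
}
selected = "sqlite"  # stands in for _config["DATABASE"].get("DEFAULT", "sqlite")
DATABASES["default"] = DATABASES[selected]
assert DATABASES["default"]["ENGINE"].endswith("sqlite3")
# Since "default" aliases the same dict object, editing one edits the other.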