Code optimization: improve clarity and maintainability.
This commit is contained in:
parent 1988a400c9
commit de84796560
3  .gitignore  (vendored)
@@ -13,4 +13,5 @@ docs/.cache/
 .qodo
 db.sqlite3
 #pr_agent/
-static/admin/
+static/admin/
+config.local.ini
2  Pipfile
@@ -20,9 +20,9 @@ pygithub = "*"
 python-gitlab = "*"
 retry = "*"
 fastapi = "*"
+psycopg2-binary = "*"

 [dev-packages]

 [requires]
 python_version = "3.12"
-
106  Pipfile.lock  (generated)
@@ -1,7 +1,7 @@
 {
     "_meta": {
         "hash": {
-            "sha256": "420206f7faa4351eabc368a83deae9b7ed9e50b0975ac63a46d6367e9920848b"
+            "sha256": "497c1ff8497659883faf8dcca407665df1b3a37f67720f64b139f9dec8202892"
         },
         "pipfile-spec": 6,
         "requires": {
@@ -169,20 +169,19 @@
         },
         "boto3": {
             "hashes": [
-                "sha256:01015b38017876d79efd7273f35d9a4adfba505237159621365bed21b9b65eca",
-                "sha256:03bd8c93b226f07d944fd6b022e11a307bff94ab6a21d51675d7e3ea81ee8424"
+                "sha256:e58136d52d79425ce26c3c1578bf94d4b2e91ead55fed9f6950406ee9713e6af"
             ],
             "index": "pip_conf_index_global",
             "markers": "python_version >= '3.8'",
-            "version": "==1.37.0"
+            "version": "==1.37.2"
         },
         "botocore": {
             "hashes": [
-                "sha256:b129d091a8360b4152ab65327186bf4e250de827c4a9b7ddf40a72b1acf1f3c1",
-                "sha256:d01661f38c0edac87424344cdf4169f3ab9bc1bf1b677c8b230d025eb66c54a3"
+                "sha256:3f460f3c32cd6d747d5897a9cbde011bf1715abc7bf0a6ea6fdb0b812df63287",
+                "sha256:5f59b966f3cd0c8055ef6f7c2600f7db5f8218071d992e5f95da3f9156d4370f"
             ],
             "markers": "python_version >= '3.8'",
-            "version": "==1.37.0"
+            "version": "==1.37.2"
         },
         "certifi": {
             "hashes": [
@@ -460,12 +459,12 @@
         },
         "django-import-export": {
             "hashes": [
-                "sha256:317842a64233025a277040129fb6792fc48fd39622c185b70bf8c18c393d708f",
-                "sha256:ecb4e6cdb4790d69bce261f9cca1007ca19cb431bb5a950ba907898245c8817b"
+                "sha256:5514d09636e84e823a42cd5e79292f70f20d6d2feed117a145f5b64a5b44f168",
+                "sha256:bd3fe0aa15a2bce9de4be1a2f882e2c4539fdbfdfa16f2052c98dd7aec0f085c"
             ],
             "index": "pip_conf_index_global",
             "markers": "python_version >= '3.9'",
-            "version": "==4.3.6"
+            "version": "==4.3.7"
         },
         "django-simpleui": {
             "hashes": [
@@ -794,12 +793,12 @@
         },
         "litellm": {
             "hashes": [
-                "sha256:02df5865f98ea9734a4d27ac7c33aad9a45c4015403d5c0797d3292ade3c5cb5",
-                "sha256:d241436ac0edf64ec57fb5686f8d84a25998a7e52213d9063adf87df8432701f"
+                "sha256:eaab989c090ccc094b41c3fdf27d1df7f6fb25e091ab0ce48e0f3079f1e51ff5",
+                "sha256:ff9137c008cdb421db32defb1fbd1ed546a95167de6d276c61b664582ed4ff60"
            ],
            "index": "pip_conf_index_global",
            "markers": "python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7' and python_version >= '3.8'",
-            "version": "==1.61.16"
+            "version": "==1.61.17"
        },
        "loguru": {
            "hashes": [
@@ -1196,6 +1195,81 @@
             "markers": "python_version >= '3.6'",
             "version": "==7.0.0"
         },
+        "psycopg2-binary": {
+            "hashes": [
+                "sha256:04392983d0bb89a8717772a193cfaac58871321e3ec69514e1c4e0d4957b5aff",
+                "sha256:056470c3dc57904bbf63d6f534988bafc4e970ffd50f6271fc4ee7daad9498a5",
+                "sha256:0ea8e3d0ae83564f2fc554955d327fa081d065c8ca5cc6d2abb643e2c9c1200f",
+                "sha256:155e69561d54d02b3c3209545fb08938e27889ff5a10c19de8d23eb5a41be8a5",
+                "sha256:18c5ee682b9c6dd3696dad6e54cc7ff3a1a9020df6a5c0f861ef8bfd338c3ca0",
+                "sha256:19721ac03892001ee8fdd11507e6a2e01f4e37014def96379411ca99d78aeb2c",
+                "sha256:1a6784f0ce3fec4edc64e985865c17778514325074adf5ad8f80636cd029ef7c",
+                "sha256:2286791ececda3a723d1910441c793be44625d86d1a4e79942751197f4d30341",
+                "sha256:230eeae2d71594103cd5b93fd29d1ace6420d0b86f4778739cb1a5a32f607d1f",
+                "sha256:245159e7ab20a71d989da00f280ca57da7641fa2cdcf71749c193cea540a74f7",
+                "sha256:26540d4a9a4e2b096f1ff9cce51253d0504dca5a85872c7f7be23be5a53eb18d",
+                "sha256:270934a475a0e4b6925b5f804e3809dd5f90f8613621d062848dd82f9cd62007",
+                "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142",
+                "sha256:2ad26b467a405c798aaa1458ba09d7e2b6e5f96b1ce0ac15d82fd9f95dc38a92",
+                "sha256:2b3d2491d4d78b6b14f76881905c7a8a8abcf974aad4a8a0b065273a0ed7a2cb",
+                "sha256:2ce3e21dc3437b1d960521eca599d57408a695a0d3c26797ea0f72e834c7ffe5",
+                "sha256:30e34c4e97964805f715206c7b789d54a78b70f3ff19fbe590104b71c45600e5",
+                "sha256:3216ccf953b3f267691c90c6fe742e45d890d8272326b4a8b20850a03d05b7b8",
+                "sha256:32581b3020c72d7a421009ee1c6bf4a131ef5f0a968fab2e2de0c9d2bb4577f1",
+                "sha256:35958ec9e46432d9076286dda67942ed6d968b9c3a6a2fd62b48939d1d78bf68",
+                "sha256:3abb691ff9e57d4a93355f60d4f4c1dd2d68326c968e7db17ea96df3c023ef73",
+                "sha256:3c18f74eb4386bf35e92ab2354a12c17e5eb4d9798e4c0ad3a00783eae7cd9f1",
+                "sha256:3c4745a90b78e51d9ba06e2088a2fe0c693ae19cc8cb051ccda44e8df8a6eb53",
+                "sha256:3c4ded1a24b20021ebe677b7b08ad10bf09aac197d6943bfe6fec70ac4e4690d",
+                "sha256:3e9c76f0ac6f92ecfc79516a8034a544926430f7b080ec5a0537bca389ee0906",
+                "sha256:48b338f08d93e7be4ab2b5f1dbe69dc5e9ef07170fe1f86514422076d9c010d0",
+                "sha256:4b3df0e6990aa98acda57d983942eff13d824135fe2250e6522edaa782a06de2",
+                "sha256:512d29bb12608891e349af6a0cccedce51677725a921c07dba6342beaf576f9a",
+                "sha256:5a507320c58903967ef7384355a4da7ff3f28132d679aeb23572753cbf2ec10b",
+                "sha256:5c370b1e4975df846b0277b4deba86419ca77dbc25047f535b0bb03d1a544d44",
+                "sha256:6b269105e59ac96aba877c1707c600ae55711d9dcd3fc4b5012e4af68e30c648",
+                "sha256:6d4fa1079cab9018f4d0bd2db307beaa612b0d13ba73b5c6304b9fe2fb441ff7",
+                "sha256:6dc08420625b5a20b53551c50deae6e231e6371194fa0651dbe0fb206452ae1f",
+                "sha256:73aa0e31fa4bb82578f3a6c74a73c273367727de397a7a0f07bd83cbea696baa",
+                "sha256:7559bce4b505762d737172556a4e6ea8a9998ecac1e39b5233465093e8cee697",
+                "sha256:79625966e176dc97ddabc142351e0409e28acf4660b88d1cf6adb876d20c490d",
+                "sha256:7a813c8bdbaaaab1f078014b9b0b13f5de757e2b5d9be6403639b298a04d218b",
+                "sha256:7b2c956c028ea5de47ff3a8d6b3cc3330ab45cf0b7c3da35a2d6ff8420896526",
+                "sha256:7f4152f8f76d2023aac16285576a9ecd2b11a9895373a1f10fd9db54b3ff06b4",
+                "sha256:7f5d859928e635fa3ce3477704acee0f667b3a3d3e4bb109f2b18d4005f38287",
+                "sha256:851485a42dbb0bdc1edcdabdb8557c09c9655dfa2ca0460ff210522e073e319e",
+                "sha256:8608c078134f0b3cbd9f89b34bd60a943b23fd33cc5f065e8d5f840061bd0673",
+                "sha256:880845dfe1f85d9d5f7c412efea7a08946a46894537e4e5d091732eb1d34d9a0",
+                "sha256:8aabf1c1a04584c168984ac678a668094d831f152859d06e055288fa515e4d30",
+                "sha256:8aecc5e80c63f7459a1a2ab2c64df952051df196294d9f739933a9f6687e86b3",
+                "sha256:8cd9b4f2cfab88ed4a9106192de509464b75a906462fb846b936eabe45c2063e",
+                "sha256:8de718c0e1c4b982a54b41779667242bc630b2197948405b7bd8ce16bcecac92",
+                "sha256:9440fa522a79356aaa482aa4ba500b65f28e5d0e63b801abf6aa152a29bd842a",
+                "sha256:b5f86c56eeb91dc3135b3fd8a95dc7ae14c538a2f3ad77a19645cf55bab1799c",
+                "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8",
+                "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909",
+                "sha256:c3cc28a6fd5a4a26224007712e79b81dbaee2ffb90ff406256158ec4d7b52b47",
+                "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864",
+                "sha256:d00924255d7fc916ef66e4bf22f354a940c67179ad3fd7067d7a0a9c84d2fbfc",
+                "sha256:d7cd730dfa7c36dbe8724426bf5612798734bff2d3c3857f36f2733f5bfc7c00",
+                "sha256:e217ce4d37667df0bc1c397fdcd8de5e81018ef305aed9415c3b093faaeb10fb",
+                "sha256:e3923c1d9870c49a2d44f795df0c889a22380d36ef92440ff618ec315757e539",
+                "sha256:e5720a5d25e3b99cd0dc5c8a440570469ff82659bb09431c1439b92caf184d3b",
+                "sha256:e8b58f0a96e7a1e341fc894f62c1177a7c83febebb5ff9123b579418fdc8a481",
+                "sha256:e984839e75e0b60cfe75e351db53d6db750b00de45644c5d1f7ee5d1f34a1ce5",
+                "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4",
+                "sha256:ec8a77f521a17506a24a5f626cb2aee7850f9b69a0afe704586f63a464f3cd64",
+                "sha256:ecced182e935529727401b24d76634a357c71c9275b356efafd8a2a91ec07392",
+                "sha256:ee0e8c683a7ff25d23b55b11161c2663d4b099770f6085ff0a20d4505778d6b4",
+                "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1",
+                "sha256:f758ed67cab30b9a8d2833609513ce4d3bd027641673d4ebc9c067e4d208eec1",
+                "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567",
+                "sha256:ffe8ed017e4ed70f68b7b371d84b7d4a790368db9203dfc2d222febd3a9c8863"
+            ],
+            "index": "pip_conf_index_global",
+            "markers": "python_version >= '3.8'",
+            "version": "==2.9.10"
+        },
         "py": {
             "hashes": [
                 "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719",
@@ -1766,11 +1840,11 @@
         },
         "s3transfer": {
             "hashes": [
-                "sha256:3b39185cb72f5acc77db1a58b6e25b977f28d20496b6e58d6813d75f464d632f",
-                "sha256:be6ecb39fadd986ef1701097771f87e4d2f821f27f6071c872143884d2950fbc"
+                "sha256:ca855bdeb885174b5ffa95b9913622459d4ad8e331fc98eb01e6d5eb6a30655d",
+                "sha256:edae4977e3a122445660c7c114bba949f9d191bae3b34a096f18a1c8c354527a"
             ],
             "markers": "python_version >= '3.8'",
-            "version": "==0.11.2"
+            "version": "==0.11.3"
         },
         "simplepro": {
             "hashes": [
@@ -34,7 +34,13 @@ class GitConfigAdmin(AjaxAdmin):
 class ProjectConfigAdmin(AjaxAdmin):
     """Admin配置"""

-    list_display = ["project_id", "project_name", "project_secret", "commands", "is_enable"]
+    list_display = [
+        "project_id",
+        "project_name",
+        "project_secret",
+        "commands",
+        "is_enable",
+    ]
     readonly_fields = ["create_by", "delete_at", "detail"]
     top_html = '<el-alert title="可配置多个项目!" type="success"></el-alert>'

@@ -16,4 +16,3 @@ class Command(BaseCommand):
             print("初始化AI配置已创建")
         else:
             print("初始化AI配置已存在")
-
@@ -44,9 +44,7 @@ class GitConfig(BaseModel):
         null=True, blank=True, max_length=16, verbose_name="Git名称"
     )
     git_type = fields.RadioField(
-        choices=constant.GIT_TYPE,
-        default=0,
-        verbose_name="Git类型"
+        choices=constant.GIT_TYPE, default=0, verbose_name="Git类型"
     )
     git_url = fields.CharField(
         null=True, blank=True, max_length=128, verbose_name="Git地址"
@@ -67,6 +65,7 @@ class ProjectConfig(BaseModel):
     """
     项目配置表
     """
+
     git_config = fields.ForeignKey(
         GitConfig,
         null=True,
@@ -89,10 +88,7 @@ class ProjectConfig(BaseModel):
         max_length=256,
         verbose_name="默认命令",
     )
-    is_enable = fields.SwitchField(
-        default=True,
-        verbose_name="是否启用"
-    )
+    is_enable = fields.SwitchField(default=True, verbose_name="是否启用")

     class Meta:
         verbose_name = "项目配置"
@@ -106,6 +102,7 @@ class ProjectHistory(BaseModel):
     """
     项目历史表
     """
+
     project = fields.ForeignKey(
         ProjectConfig,
         null=True,
@@ -128,9 +125,7 @@ class ProjectHistory(BaseModel):
     mr_title = fields.CharField(
         null=True, blank=True, max_length=256, verbose_name="MR标题"
     )
-    source_data = models.JSONField(
-        null=True, blank=True, verbose_name="源数据"
-    )
+    source_data = models.JSONField(null=True, blank=True, verbose_name="源数据")

     class Meta:
         verbose_name = "项目历史"
@@ -12,12 +12,7 @@ from utils import constant


 def load_project_config(
-    git_url,
-    access_token,
-    project_secret,
-    openai_api_base,
-    openai_key,
-    llm_model
+    git_url, access_token, project_secret, openai_api_base, openai_key, llm_model
 ):
     """
     加载项目配置
@@ -36,12 +31,11 @@
         "secret": project_secret,
         "openai_api_base": openai_api_base,
         "openai_key": openai_key,
-        "llm_model": llm_model
+        "llm_model": llm_model,
     }


 class WebHookView(View):
-
     @staticmethod
     def select_git_provider(git_type):
         """
@@ -82,7 +76,9 @@ class WebHookView(View):
         project_config = provider.get_project_config(project_id=project_id)

         # Token 校验
-        provider.check_secret(request_headers=headers, project_secret=project_config.get("project_secret"))
+        provider.check_secret(
+            request_headers=headers, project_secret=project_config.get("project_secret")
+        )

         provider.get_merge_request(
             request_data=json_data,
@@ -91,11 +87,13 @@ class WebHookView(View):
             api_base=project_config.get("api_base"),
             api_key=project_config.get("api_key"),
             llm_model=project_config.get("llm_model"),
-            project_commands=project_config.get("commands")
+            project_commands=project_config.get("commands"),
         )

         # 记录请求日志: 目前仅记录合并日志
         if json_data.get('object_kind') == 'merge_request':
-            provider.save_pr_agent_log(request_data=json_data, project_id=project_config.get("project_id"))
+            provider.save_pr_agent_log(
+                request_data=json_data, project_id=project_config.get("project_id")
+            )

         return JsonResponse(status=200, data={"status": "ignored"})
@@ -1,8 +1,4 @@
-GIT_TYPE = (
-    (0, "gitlab"),
-    (1, "github"),
-    (2, "gitea")
-)
+GIT_TYPE = ((0, "gitlab"), (1, "github"), (2, "gitea"))

 DEFAULT_COMMANDS = (
     ("/review", "/review"),
@@ -10,11 +6,7 @@ DEFAULT_COMMANDS = (
     ("/improve_code", "/improve_code"),
 )

-UA_TYPE = {
-    "GitLab": "gitlab",
-    "GitHub": "github",
-    "Go-http-client": "gitea"
-}
+UA_TYPE = {"GitLab": "gitlab", "GitHub": "github", "Go-http-client": "gitea"}


 def get_git_type_from_ua(ua_value):
@@ -16,14 +16,14 @@ class GitProvider(ABC):

     @abstractmethod
     def get_merge_request(
-            self,
-            request_data,
-            git_url,
-            access_token,
-            api_base,
-            api_key,
-            llm_model,
-            project_commands
+        self,
+        request_data,
+        git_url,
+        access_token,
+        api_base,
+        api_key,
+        llm_model,
+        project_commands,
     ):
         pass

@@ -33,7 +33,6 @@ class GitProvider(ABC):


 class GitLabProvider(GitProvider):
-
     @staticmethod
     def check_secret(request_headers, project_secret):
         """
@@ -79,18 +78,18 @@ class GitLabProvider(GitProvider):
             "access_token": git_config.access_token,
             "project_secret": project_config.project_secret,
             "commands": project_config.commands.split(","),
-            "project_id": project_config.id
+            "project_id": project_config.id,
         }

     def get_merge_request(
-            self,
-            request_data,
-            git_url,
-            access_token,
-            api_base,
-            api_key,
-            llm_model,
-            project_commands,
+        self,
+        request_data,
+        git_url,
+        access_token,
+        api_base,
+        api_key,
+        llm_model,
+        project_commands,
     ):
         """
         实现GitLab Merge Request获取逻辑
@@ -124,7 +123,10 @@ class GitLabProvider(GitProvider):
             self.run_command(mr_url, project_commands)
             # 数据库留存
             return JsonResponse(status=200, data={"status": "review started"})
-        return JsonResponse(status=400, data={"error": "Merge request URL not found or action not open"})
+        return JsonResponse(
+            status=400,
+            data={"error": "Merge request URL not found or action not open"},
+        )

     @staticmethod
     def save_pr_agent_log(request_data, project_id):
@@ -134,13 +136,19 @@ class GitLabProvider(GitProvider):
         :param project_id:
         :return:
         """
-        if request_data.get('object_attributes', {}).get("source_branch") and request_data.get('object_attributes', {}).get("target_branch"):
+        if request_data.get('object_attributes', {}).get(
+            "source_branch"
+        ) and request_data.get('object_attributes', {}).get("target_branch"):
             models.ProjectHistory.objects.create(
                 project_id=project_id,
                 project_url=request_data.get("project", {}).get("web_url"),
                 mr_url=request_data.get('object_attributes', {}).get("url"),
-                source_branch=request_data.get('object_attributes', {}).get("source_branch"),
-                target_branch=request_data.get('object_attributes', {}).get("target_branch"),
+                source_branch=request_data.get('object_attributes', {}).get(
+                    "source_branch"
+                ),
+                target_branch=request_data.get('object_attributes', {}).get(
+                    "target_branch"
+                ),
                 mr_title=request_data.get('object_attributes', {}).get("title"),
                 source_data=request_data,
             )
@@ -80,14 +80,20 @@ class PRAgent:
         if action == "answer":
             if notify:
                 notify()
-            await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run()
+            await PRReviewer(
+                pr_url, is_answer=True, args=args, ai_handler=self.ai_handler
+            ).run()
         elif action == "auto_review":
-            await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run()
+            await PRReviewer(
+                pr_url, is_auto=True, args=args, ai_handler=self.ai_handler
+            ).run()
         elif action in command2class:
             if notify:
                 notify()

-            await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run()
+            await command2class[action](
+                pr_url, ai_handler=self.ai_handler, args=args
+            ).run()
         else:
             return False
         return True
@@ -88,7 +88,7 @@ USER_MESSAGE_ONLY_MODELS = [
     "deepseek/deepseek-reasoner",
     "o1-mini",
     "o1-mini-2024-09-12",
-    "o1-preview"
+    "o1-preview",
 ]

 NO_SUPPORT_TEMPERATURE_MODELS = [
@@ -99,5 +99,5 @@ NO_SUPPORT_TEMPERATURE_MODELS = [
     "o1-2024-12-17",
     "o3-mini",
     "o3-mini-2025-01-31",
-    "o1-preview"
+    "o1-preview",
 ]
@@ -16,7 +16,14 @@ class BaseAiHandler(ABC):
         pass

     @abstractmethod
-    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
+    async def chat_completion(
+        self,
+        model: str,
+        system: str,
+        user: str,
+        temperature: float = 0.2,
+        img_path: str = None,
+    ):
         """
         This method should be implemented to return a chat completion from the AI model.
         Args:
@@ -34,9 +34,16 @@ class LangChainOpenAIHandler(BaseAiHandler):
         """
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

-    @retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError),
-           tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
-    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
+    @retry(
+        exceptions=(APIError, Timeout, AttributeError, RateLimitError),
+        tries=OPENAI_RETRIES,
+        delay=2,
+        backoff=2,
+        jitter=(1, 3),
+    )
+    async def chat_completion(
+        self, model: str, system: str, user: str, temperature: float = 0.2
+    ):
         try:
             messages = [SystemMessage(content=system), HumanMessage(content=user)]

@@ -45,7 +52,7 @@ class LangChainOpenAIHandler(BaseAiHandler):
             finish_reason = "completed"
             return resp.content, finish_reason

-        except (Exception) as e:
+        except Exception as e:
             get_logger().error("Unknown error during OpenAI inference: ", e)
             raise e

@@ -66,7 +73,10 @@ class LangChainOpenAIHandler(BaseAiHandler):
             if openai_api_base is None or len(openai_api_base) == 0:
                 return ChatOpenAI(openai_api_key=get_settings().openai.key)
             else:
-                return ChatOpenAI(openai_api_key=get_settings().openai.key, openai_api_base=openai_api_base)
+                return ChatOpenAI(
+                    openai_api_key=get_settings().openai.key,
+                    openai_api_base=openai_api_base,
+                )
         except AttributeError as e:
             if getattr(e, "name"):
                 raise ValueError(f"OpenAI {e.name} is required") from e
@@ -36,9 +36,14 @@ class LiteLLMAIHandler(BaseAiHandler):
         elif 'OPENAI_API_KEY' not in os.environ:
             litellm.api_key = "dummy_key"
         if get_settings().get("aws.AWS_ACCESS_KEY_ID"):
-            assert get_settings().aws.AWS_SECRET_ACCESS_KEY and get_settings().aws.AWS_REGION_NAME, "AWS credentials are incomplete"
+            assert (
+                get_settings().aws.AWS_SECRET_ACCESS_KEY
+                and get_settings().aws.AWS_REGION_NAME
+            ), "AWS credentials are incomplete"
             os.environ["AWS_ACCESS_KEY_ID"] = get_settings().aws.AWS_ACCESS_KEY_ID
-            os.environ["AWS_SECRET_ACCESS_KEY"] = get_settings().aws.AWS_SECRET_ACCESS_KEY
+            os.environ[
+                "AWS_SECRET_ACCESS_KEY"
+            ] = get_settings().aws.AWS_SECRET_ACCESS_KEY
             os.environ["AWS_REGION_NAME"] = get_settings().aws.AWS_REGION_NAME
         if get_settings().get("litellm.use_client"):
             litellm_token = get_settings().get("litellm.LITELLM_TOKEN")
@@ -73,14 +78,19 @@ class LiteLLMAIHandler(BaseAiHandler):
             litellm.replicate_key = get_settings().replicate.key
         if get_settings().get("HUGGINGFACE.KEY", None):
             litellm.huggingface_key = get_settings().huggingface.key
-        if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
+        if (
+            get_settings().get("HUGGINGFACE.API_BASE", None)
+            and 'huggingface' in get_settings().config.model
+        ):
             litellm.api_base = get_settings().huggingface.api_base
             self.api_base = get_settings().huggingface.api_base
         if get_settings().get("OLLAMA.API_BASE", None):
             litellm.api_base = get_settings().ollama.api_base
             self.api_base = get_settings().ollama.api_base
         if get_settings().get("HUGGINGFACE.REPETITION_PENALTY", None):
-            self.repetition_penalty = float(get_settings().huggingface.repetition_penalty)
+            self.repetition_penalty = float(
+                get_settings().huggingface.repetition_penalty
+            )
         if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
             litellm.vertex_project = get_settings().vertexai.vertex_project
             litellm.vertex_location = get_settings().get(
@@ -89,7 +99,9 @@ class LiteLLMAIHandler(BaseAiHandler):
         # Google AI Studio
         # SEE https://docs.litellm.ai/docs/providers/gemini
         if get_settings().get("GOOGLE_AI_STUDIO.GEMINI_API_KEY", None):
-            os.environ["GEMINI_API_KEY"] = get_settings().google_ai_studio.gemini_api_key
+            os.environ[
+                "GEMINI_API_KEY"
+            ] = get_settings().google_ai_studio.gemini_api_key

         # Support deepseek models
         if get_settings().get("DEEPSEEK.KEY", None):
@@ -140,27 +152,35 @@ class LiteLLMAIHandler(BaseAiHandler):
         git_provider = get_settings().config.git_provider

         metadata = dict()
-        callbacks = litellm.success_callback + litellm.failure_callback + litellm.service_callback
+        callbacks = (
+            litellm.success_callback
+            + litellm.failure_callback
+            + litellm.service_callback
+        )
         if "langfuse" in callbacks:
-            metadata.update({
-                "trace_name": command,
-                "tags": [git_provider, command, f'version:{get_version()}'],
-                "trace_metadata": {
-                    "command": command,
-                    "pr_url": pr_url,
-                },
-            })
-        if "langsmith" in callbacks:
-            metadata.update({
-                "run_name": command,
-                "tags": [git_provider, command, f'version:{get_version()}'],
-                "extra": {
-                    "metadata": {
-                        "command": command,
-                        "pr_url": pr_url,
-                    }
-                },
-            })
+            metadata.update(
+                {
+                    "trace_name": command,
+                    "tags": [git_provider, command, f'version:{get_version()}'],
+                    "trace_metadata": {
+                        "command": command,
+                        "pr_url": pr_url,
+                    },
+                }
+            )
+        if "langsmith" in callbacks:
+            metadata.update(
+                {
+                    "run_name": command,
+                    "tags": [git_provider, command, f'version:{get_version()}'],
+                    "extra": {
+                        "metadata": {
+                            "command": command,
+                            "pr_url": pr_url,
+                        }
+                    },
+                }
+            )

         # Adding the captured logs to the kwargs
         kwargs["metadata"] = metadata
@@ -175,10 +195,19 @@ class LiteLLMAIHandler(BaseAiHandler):
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

     @retry(
-        retry=retry_if_exception_type((openai.APIError, openai.APIConnectionError, openai.APITimeoutError)), # No retry on RateLimitError
-        stop=stop_after_attempt(OPENAI_RETRIES)
+        retry=retry_if_exception_type(
+            (openai.APIError, openai.APIConnectionError, openai.APITimeoutError)
+        ),  # No retry on RateLimitError
+        stop=stop_after_attempt(OPENAI_RETRIES),
     )
-    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
+    async def chat_completion(
+        self,
+        model: str,
+        system: str,
+        user: str,
+        temperature: float = 0.2,
+        img_path: str = None,
+    ):
         try:
             resp, finish_reason = None, None
             deployment_id = self.deployment_id
@@ -187,8 +216,12 @@ class LiteLLMAIHandler(BaseAiHandler):
             if 'claude' in model and not system:
                 system = "No system prompt provided"
                 get_logger().warning(
-                    "Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error.")
-            messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
+                    "Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error."
+                )
+            messages = [
+                {"role": "system", "content": system},
+                {"role": "user", "content": user},
+            ]

             if img_path:
                 try:
@@ -201,14 +234,21 @@ class LiteLLMAIHandler(BaseAiHandler):
                 except Exception as e:
                     get_logger().error(f"Error fetching image: {img_path}", e)
                     return f"Error fetching image: {img_path}", "error"
-                messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]},
-                                          {"type": "image_url", "image_url": {"url": img_path}}]
+                messages[1]["content"] = [
+                    {"type": "text", "text": messages[1]["content"]},
+                    {"type": "image_url", "image_url": {"url": img_path}},
+                ]

             # Currently, some models do not support a separate system and user prompts
-            if model in self.user_message_only_models or get_settings().config.custom_reasoning_model:
+            if (
+                model in self.user_message_only_models
+                or get_settings().config.custom_reasoning_model
+            ):
                 user = f"{system}\n\n\n{user}"
                 system = ""
-                get_logger().info(f"Using model {model}, combining system and user prompts")
+                get_logger().info(
+                    f"Using model {model}, combining system and user prompts"
+                )
                 messages = [{"role": "user", "content": user}]
             kwargs = {
                 "model": model,
@@ -227,7 +267,10 @@ class LiteLLMAIHandler(BaseAiHandler):
             }

             # Add temperature only if model supports it
-            if model not in self.no_support_temperature_models and not get_settings().config.custom_reasoning_model:
+            if (
+                model not in self.no_support_temperature_models
+                and not get_settings().config.custom_reasoning_model
+            ):
                 kwargs["temperature"] = temperature

             if get_settings().litellm.get("enable_callbacks", False):
@@ -235,7 +278,9 @@ class LiteLLMAIHandler(BaseAiHandler):

             seed = get_settings().config.get("seed", -1)
             if temperature > 0 and seed >= 0:
-                raise ValueError(f"Seed ({seed}) is not supported with temperature ({temperature}) > 0")
+                raise ValueError(
+                    f"Seed ({seed}) is not supported with temperature ({temperature}) > 0"
+                )
             elif seed >= 0:
                 get_logger().info(f"Using fixed seed of {seed}")
                 kwargs["seed"] = seed
@@ -253,10 +298,10 @@ class LiteLLMAIHandler(BaseAiHandler):
         except (openai.APIError, openai.APITimeoutError) as e:
             get_logger().warning(f"Error during LLM inference: {e}")
             raise
-        except (openai.RateLimitError) as e:
+        except openai.RateLimitError as e:
             get_logger().error(f"Rate limit error during LLM inference: {e}")
             raise
-        except (Exception) as e:
+        except Exception as e:
             get_logger().warning(f"Unknown error during LLM inference: {e}")
             raise openai.APIError from e
         if response is None or len(response["choices"]) == 0:
@@ -267,7 +312,9 @@ class LiteLLMAIHandler(BaseAiHandler):
         get_logger().debug(f"\nAI response:\n{resp}")

         # log the full response for debugging
-        response_log = self.prepare_logs(response, system, user, resp, finish_reason)
+        response_log = self.prepare_logs(
+            response, system, user, resp, finish_reason
+        )
         get_logger().debug("Full_response", artifact=response_log)

         # for CLI debugging
@@ -37,13 +37,23 @@ class OpenAIHandler(BaseAiHandler):
         """
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

-    @retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError),
-           tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
-    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
+    @retry(
+        exceptions=(APIError, Timeout, AttributeError, RateLimitError),
+        tries=OPENAI_RETRIES,
+        delay=2,
+        backoff=2,
+        jitter=(1, 3),
+    )
+    async def chat_completion(
+        self, model: str, system: str, user: str, temperature: float = 0.2
+    ):
         try:
             get_logger().info("System: ", system)
             get_logger().info("User: ", user)
-            messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
+            messages = [
+                {"role": "system", "content": system},
+                {"role": "user", "content": user},
+            ]
             client = AsyncOpenAI()
             chat_completion = await client.chat.completions.create(
                 model=model,
@@ -53,15 +63,21 @@ class OpenAIHandler(BaseAiHandler):
             resp = chat_completion.choices[0].message.content
             finish_reason = chat_completion.choices[0].finish_reason
             usage = chat_completion.usage
-            get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
-                              model=model, usage=usage)
+            get_logger().info(
+                "AI response",
+                response=resp,
+                messages=messages,
+                finish_reason=finish_reason,
+                model=model,
+                usage=usage,
+            )
             return resp, finish_reason
         except (APIError, Timeout) as e:
             get_logger().error("Error during OpenAI inference: ", e)
             raise
-        except (RateLimitError) as e:
+        except RateLimitError as e:
             get_logger().error("Rate limit error during OpenAI inference: ", e)
             raise
-        except (Exception) as e:
+        except Exception as e:
             get_logger().error("Unknown error during OpenAI inference: ", e)
             raise
@@ -1,6 +1,7 @@
 from base64 import b64decode
+import hashlib


 class CliArgs:
     @staticmethod
     def validate_user_args(args: list) -> (bool, str):
@@ -23,12 +24,12 @@ class CliArgs:
             for arg in args:
                 if arg.startswith('--'):
                     arg_word = arg.lower()
-                    arg_word = arg_word.replace('__', '.') # replace double underscore with dot, e.g. --openai__key -> --openai.key
+                    arg_word = arg_word.replace(
+                        '__', '.'
+                    )  # replace double underscore with dot, e.g. --openai__key -> --openai.key
                     for forbidden_arg_word in forbidden_cli_args:
                         if forbidden_arg_word in arg_word:
                             return False, forbidden_arg_word
             return True, ""
         except Exception as e:
             return False, str(e)
-
-
@@ -4,7 +4,7 @@ import re
 from utils.pr_agent.config_loader import get_settings


-def filter_ignored(files, platform = 'github'):
+def filter_ignored(files, platform='github'):
     """
     Filter out files that match the ignore patterns.
     """
@@ -15,7 +15,9 @@ def filter_ignored(files, platform='github'):
     if isinstance(patterns, str):
         patterns = [patterns]
     glob_setting = get_settings().ignore.glob
-    if isinstance(glob_setting, str):  # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
+    if isinstance(
+        glob_setting, str
+    ):  # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
         glob_setting = glob_setting.strip('[]').split(",")
     patterns += [fnmatch.translate(glob) for glob in glob_setting]

@@ -31,7 +33,9 @@ def filter_ignored(files, platform='github'):
     if files and isinstance(files, list):
         for r in compiled_patterns:
             if platform == 'github':
-                files = [f for f in files if (f.filename and not r.match(f.filename))]
+                files = [
+                    f for f in files if (f.filename and not r.match(f.filename))
+                ]
             elif platform == 'bitbucket':
                 # files = [f for f in files if (f.new.path and not r.match(f.new.path))]
                 files_o = []
@@ -49,10 +53,18 @@ def filter_ignored(files, platform='github'):
                 # files = [f for f in files if (f['new_path'] and not r.match(f['new_path']))]
                 files_o = []
                 for f in files:
-                    if 'new_path' in f and f['new_path'] and not r.match(f['new_path']):
+                    if (
+                        'new_path' in f
+                        and f['new_path']
+                        and not r.match(f['new_path'])
+                    ):
                         files_o.append(f)
                         continue
-                    if 'old_path' in f and f['old_path'] and not r.match(f['old_path']):
+                    if (
+                        'old_path' in f
+                        and f['old_path']
+                        and not r.match(f['old_path'])
+                    ):
                         files_o.append(f)
                         continue
                 files = files_o
@@ -8,9 +8,18 @@ from utils.pr_agent.config_loader import get_settings
 from utils.pr_agent.log import get_logger


-def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
-                 patch_extra_lines_after=0, filename: str = "") -> str:
-    if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0) or not original_file_str:
+def extend_patch(
+    original_file_str,
+    patch_str,
+    patch_extra_lines_before=0,
+    patch_extra_lines_after=0,
+    filename: str = "",
+) -> str:
+    if (
+        not patch_str
+        or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0)
+        or not original_file_str
+    ):
         return patch_str

     original_file_str = decode_if_bytes(original_file_str)
@@ -21,10 +30,17 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
         return patch_str

     try:
-        extended_patch_str = process_patch_lines(patch_str, original_file_str,
-                                                 patch_extra_lines_before, patch_extra_lines_after)
+        extended_patch_str = process_patch_lines(
+            patch_str,
+            original_file_str,
+            patch_extra_lines_before,
+            patch_extra_lines_after,
+        )
     except Exception as e:
-        get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
+        get_logger().warning(
+            f"Failed to extend patch: {e}",
+            artifact={"traceback": traceback.format_exc()},
+        )
         return patch_str

     return extended_patch_str
@@ -48,13 +64,19 @@ def decode_if_bytes(original_file_str):
 def should_skip_patch(filename):
     patch_extension_skip_types = get_settings().config.patch_extension_skip_types
     if patch_extension_skip_types and filename:
-        return any(filename.endswith(skip_type) for skip_type in patch_extension_skip_types)
+        return any(
+            filename.endswith(skip_type) for skip_type in patch_extension_skip_types
+        )
     return False


-def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after):
+def process_patch_lines(
+    patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after
+):
     allow_dynamic_context = get_settings().config.allow_dynamic_context
-    patch_extra_lines_before_dynamic = get_settings().config.max_extra_lines_before_dynamic_context
+    patch_extra_lines_before_dynamic = (
+        get_settings().config.max_extra_lines_before_dynamic_context
+    )

     original_lines = original_file_str.splitlines()
     len_original_lines = len(original_lines)
@@ -63,59 +85,122 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before,

     is_valid_hunk = True
     start1, size1, start2, size2 = -1, -1, -1, -1
-    RE_HUNK_HEADER = re.compile(
-        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
+    RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
     try:
-        for i,line in enumerate(patch_lines):
+        for i, line in enumerate(patch_lines):
             if line.startswith('@@'):
                 match = RE_HUNK_HEADER.match(line)
                 # identify hunk header
                 if match:
                     # finish processing previous hunk
                     if is_valid_hunk and (start1 != -1 and patch_extra_lines_after > 0):
-                        delta_lines = [f' {line}' for line in original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]]
+                        delta_lines = [
+                            f' {line}'
+                            for line in original_lines[
+                                start1
+                                + size1
+                                - 1 : start1
+                                + size1
+                                - 1
+                                + patch_extra_lines_after
+                            ]
+                        ]
                         extended_patch_lines.extend(delta_lines)

-                    section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
+                    section_header, size1, size2, start1, start2 = extract_hunk_headers(
+                        match
+                    )

-                    is_valid_hunk = check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1)
+                    is_valid_hunk = check_if_hunk_lines_matches_to_file(
+                        i, original_lines, patch_lines, start1
+                    )

-                    if is_valid_hunk and (patch_extra_lines_before > 0 or patch_extra_lines_after > 0):
+                    if is_valid_hunk and (
+                        patch_extra_lines_before > 0 or patch_extra_lines_after > 0
+                    ):
+
                         def _calc_context_limits(patch_lines_before):
                             extended_start1 = max(1, start1 - patch_lines_before)
-                            extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after
+                            extended_size1 = (
+                                size1
+                                + (start1 - extended_start1)
+                                + patch_extra_lines_after
+                            )
                             extended_start2 = max(1, start2 - patch_lines_before)
-                            extended_size2 = size2 + (start2 - extended_start2) + patch_extra_lines_after
-                            if extended_start1 - 1 + extended_size1 > len_original_lines:
+                            extended_size2 = (
+                                size2
+                                + (start2 - extended_start2)
+                                + patch_extra_lines_after
+                            )
+                            if (
+                                extended_start1 - 1 + extended_size1
+                                > len_original_lines
+                            ):
                                 # we cannot extend beyond the original file
-                                delta_cap = extended_start1 - 1 + extended_size1 - len_original_lines
+                                delta_cap = (
+                                    extended_start1
+                                    - 1
+                                    + extended_size1
+                                    - len_original_lines
+                                )
                                 extended_size1 = max(extended_size1 - delta_cap, size1)
                                 extended_size2 = max(extended_size2 - delta_cap, size2)
-                            return extended_start1, extended_size1, extended_start2, extended_size2
+                            return (
+                                extended_start1,
+                                extended_size1,
+                                extended_start2,
+                                extended_size2,
+                            )

                         if allow_dynamic_context:
-                            extended_start1, extended_size1, extended_start2, extended_size2 = \
-                                _calc_context_limits(patch_extra_lines_before_dynamic)
-                            lines_before = original_lines[extended_start1 - 1:start1 - 1]
+                            (
+                                extended_start1,
+                                extended_size1,
+                                extended_start2,
+                                extended_size2,
+                            ) = _calc_context_limits(patch_extra_lines_before_dynamic)
+                            lines_before = original_lines[
+                                extended_start1 - 1 : start1 - 1
+                            ]
                             found_header = False
-                            for i, line, in enumerate(lines_before):
+                            for (
+                                i,
+                                line,
+                            ) in enumerate(lines_before):
                                 if section_header in line:
                                     found_header = True
                                     # Update start and size in one line each
-                                    extended_start1, extended_start2 = extended_start1 + i, extended_start2 + i
-                                    extended_size1, extended_size2 = extended_size1 - i, extended_size2 - i
+                                    extended_start1, extended_start2 = (
+                                        extended_start1 + i,
+                                        extended_start2 + i,
+                                    )
+                                    extended_size1, extended_size2 = (
+                                        extended_size1 - i,
+                                        extended_size2 - i,
+                                    )
                                     # get_logger().debug(f"Found section header in line {i} before the hunk")
                                     section_header = ''
                                     break
                             if not found_header:
                                 # get_logger().debug(f"Section header not found in the extra lines before the hunk")
-                                extended_start1, extended_size1, extended_start2, extended_size2 = \
-                                    _calc_context_limits(patch_extra_lines_before)
+                                (
+                                    extended_start1,
+                                    extended_size1,
+                                    extended_start2,
+                                    extended_size2,
+                                ) = _calc_context_limits(patch_extra_lines_before)
                         else:
-                            extended_start1, extended_size1, extended_start2, extended_size2 = \
-                                _calc_context_limits(patch_extra_lines_before)
+                            (
+                                extended_start1,
+                                extended_size1,
+                                extended_start2,
+                                extended_size2,
+                            ) = _calc_context_limits(patch_extra_lines_before)

-                        delta_lines = [f' {line}' for line in original_lines[extended_start1 - 1:start1 - 1]]
+                        delta_lines = [
+                            f' {line}'
+                            for line in original_lines[extended_start1 - 1 : start1 - 1]
+                        ]

                         # logic to remove section header if its in the extra delta lines (in dynamic context, this is also done)
                         if section_header and not allow_dynamic_context:
@@ -132,17 +217,23 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before,
                     extended_patch_lines.append('')
                     extended_patch_lines.append(
                         f'@@ -{extended_start1},{extended_size1} '
-                        f'+{extended_start2},{extended_size2} @@ {section_header}')
+                        f'+{extended_start2},{extended_size2} @@ {section_header}'
+                    )
                     extended_patch_lines.extend(delta_lines)  # one to zero based
                     continue
             extended_patch_lines.append(line)
     except Exception as e:
-        get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
+        get_logger().warning(
+            f"Failed to extend patch: {e}",
+            artifact={"traceback": traceback.format_exc()},
+        )
         return patch_str

     # finish processing last hunk
     if start1 != -1 and patch_extra_lines_after > 0 and is_valid_hunk:
-        delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
+        delta_lines = original_lines[
+            start1 + size1 - 1 : start1 + size1 - 1 + patch_extra_lines_after
+        ]
         # add space at the beginning of each extra line
         delta_lines = [f' {line}' for line in delta_lines]
         extended_patch_lines.extend(delta_lines)
@@ -158,11 +249,14 @@ def check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1):
     """
     is_valid_hunk = True
     try:
-        if i + 1 < len(patch_lines) and patch_lines[i + 1][0] == ' ':  # an existing line in the file
+        if (
+            i + 1 < len(patch_lines) and patch_lines[i + 1][0] == ' '
+        ):  # an existing line in the file
             if patch_lines[i + 1].strip() != original_lines[start1 - 1].strip():
                 is_valid_hunk = False
                 get_logger().error(
-                    f"Invalid hunk in PR, line {start1} in hunk header doesn't match the original file content")
+                    f"Invalid hunk in PR, line {start1} in hunk header doesn't match the original file content"
+                )
     except:
         pass
     return is_valid_hunk
@@ -195,8 +289,7 @@ def omit_deletion_hunks(patch_lines) -> str:
     added_patched = []
     add_hunk = False
     inside_hunk = False
-    RE_HUNK_HEADER = re.compile(
-        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))?\ @@[ ]?(.*)")
+    RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))?\ @@[ ]?(.*)")

     for line in patch_lines:
         if line.startswith('@@'):
@@ -221,8 +314,13 @@ def omit_deletion_hunks(patch_lines) -> str:
     return '\n'.join(added_patched)


-def handle_patch_deletions(patch: str, original_file_content_str: str,
-                           new_file_content_str: str, file_name: str, edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN) -> str:
+def handle_patch_deletions(
+    patch: str,
+    original_file_content_str: str,
+    new_file_content_str: str,
+    file_name: str,
+    edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN,
+) -> str:
     """
     Handle entire file or deletion patches.

@@ -239,11 +337,13 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
         str: The modified patch with deletion hunks omitted.

     """
-    if not new_file_content_str and (edit_type == EDIT_TYPE.DELETED or edit_type == EDIT_TYPE.UNKNOWN):
+    if not new_file_content_str and (
+        edit_type == EDIT_TYPE.DELETED or edit_type == EDIT_TYPE.UNKNOWN
+    ):
         # logic for handling deleted files - don't show patch, just show that the file was deleted
         if get_settings().config.verbosity_level > 0:
             get_logger().info(f"Processing file: {file_name}, minimizing deletion file")
-        patch = None # file was deleted
+        patch = None  # file was deleted
     else:
         patch_lines = patch.splitlines()
         patch_new = omit_deletion_hunks(patch_lines)
@@ -256,35 +356,35 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,

 def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
     """
-    Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of
-    the file.
+        Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of
+        the file.

-    Args:
-        patch (str): The patch string to be converted.
-        file: An object containing the filename of the file being patched.
+        Args:
+            patch (str): The patch string to be converted.
+            file: An object containing the filename of the file being patched.

-    Returns:
-        str: A string with line numbers for each hunk, indicating the new and old content of the file.
+        Returns:
+            str: A string with line numbers for each hunk, indicating the new and old content of the file.

-    example output:
-## src/file.ts
-__new hunk__
-881 line1
-882 line2
-883 line3
-887 + line4
-888 + line5
-889 line6
-890 line7
-...
-__old hunk__
- line1
- line2
-- line3
-- line4
- line5
- line6
-...
+        example output:
+    ## src/file.ts
+    __new hunk__
+    881 line1
+    882 line2
+    883 line3
+    887 + line4
+    888 + line5
+    889 line6
+    890 line7
+    ...
+    __old hunk__
+     line1
+     line2
+    - line3
+    - line4
+     line5
+     line6
+    ...
     """
     # if the file was deleted, return a message indicating that the file was deleted
     if hasattr(file, 'edit_type') and file.edit_type == EDIT_TYPE.DELETED:
@@ -292,8 +392,7 @@ __old hunk__

     patch_with_lines_str = f"\n\n## File: '{file.filename.strip()}'\n"
     patch_lines = patch.splitlines()
-    RE_HUNK_HEADER = re.compile(
-        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
+    RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
     new_content_lines = []
     old_content_lines = []
     match = None
@@ -307,20 +406,32 @@ __old hunk__
         if line.startswith('@@'):
             header_line = line
             match = RE_HUNK_HEADER.match(line)
-            if match and (new_content_lines or old_content_lines):  # found a new hunk, split the previous lines
+            if match and (
+                new_content_lines or old_content_lines
+            ):  # found a new hunk, split the previous lines
                 if prev_header_line:
                     patch_with_lines_str += f'\n{prev_header_line}\n'
                 is_plus_lines = is_minus_lines = False
                 if new_content_lines:
-                    is_plus_lines = any([line.startswith('+') for line in new_content_lines])
+                    is_plus_lines = any(
+                        [line.startswith('+') for line in new_content_lines]
+                    )
                 if old_content_lines:
-                    is_minus_lines = any([line.startswith('-') for line in old_content_lines])
-                if is_plus_lines or is_minus_lines:  # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
-                    patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__new hunk__\n'
+                    is_minus_lines = any(
+                        [line.startswith('-') for line in old_content_lines]
+                    )
+                if (
+                    is_plus_lines or is_minus_lines
+                ):  # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
+                    patch_with_lines_str = (
+                        patch_with_lines_str.rstrip() + '\n__new hunk__\n'
+                    )
                    for i, line_new in enumerate(new_content_lines):
                        patch_with_lines_str += f"{start2 + i} {line_new}\n"
                if is_minus_lines:
-                    patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__old hunk__\n'
+                    patch_with_lines_str = (
+                        patch_with_lines_str.rstrip() + '\n__old hunk__\n'
+                    )
                    for line_old in old_content_lines:
                        patch_with_lines_str += f"{line_old}\n"
                new_content_lines = []
@@ -335,8 +446,12 @@ __old hunk__
         elif line.startswith('-'):
             old_content_lines.append(line)
         else:
-            if not line and line_i:  # if this line is empty and the next line is a hunk header, skip it
-                if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith('@@'):
+            if (
+                not line and line_i
+            ):  # if this line is empty and the next line is a hunk header, skip it
+                if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith(
+                    '@@'
+                ):
                     continue
                 elif line_i + 1 == len(patch_lines):
                     continue
@@ -351,7 +466,9 @@ __old hunk__
         is_plus_lines = any([line.startswith('+') for line in new_content_lines])
     if old_content_lines:
         is_minus_lines = any([line.startswith('-') for line in old_content_lines])
-    if is_plus_lines or is_minus_lines:  # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
+    if (
+        is_plus_lines or is_minus_lines
+    ):  # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
         patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__new hunk__\n'
         for i, line_new in enumerate(new_content_lines):
             patch_with_lines_str += f"{start2 + i} {line_new}\n"
@@ -363,13 +480,16 @@ __old hunk__
     return patch_with_lines_str.rstrip()


-def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, side) -> tuple[str, str]:
+def extract_hunk_lines_from_patch(
+    patch: str, file_name, line_start, line_end, side
+) -> tuple[str, str]:
     try:
         patch_with_lines_str = f"\n\n## File: '{file_name.strip()}'\n\n"
         selected_lines = ""
         patch_lines = patch.splitlines()
         RE_HUNK_HEADER = re.compile(
-            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
+            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)"
+        )
         match = None
         start1, size1, start2, size2 = -1, -1, -1, -1
         skip_hunk = False
@@ -385,7 +505,9 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s

                 match = RE_HUNK_HEADER.match(line)

-                section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
+                section_header, size1, size2, start1, start2 = extract_hunk_headers(
+                    match
+                )

                 # check if line range is in this hunk
                 if side.lower() == 'left':
@@ -400,15 +522,26 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s
                 patch_with_lines_str += f'\n{header_line}\n'

             elif not skip_hunk:
-                if side.lower() == 'right' and line_start <= start2 + selected_lines_num <= line_end:
+                if (
+                    side.lower() == 'right'
+                    and line_start <= start2 + selected_lines_num <= line_end
+                ):
                     selected_lines += line + '\n'
-                if side.lower() == 'left' and start1 <= selected_lines_num + start1 <= line_end:
+                if (
+                    side.lower() == 'left'
+                    and start1 <= selected_lines_num + start1 <= line_end
+                ):
                     selected_lines += line + '\n'
                 patch_with_lines_str += line + '\n'
-                if not line.startswith('-'):  # currently we don't support /ask line for deleted lines
+                if not line.startswith(
+                    '-'
+                ):  # currently we don't support /ask line for deleted lines
                     selected_lines_num += 1
     except Exception as e:
-        get_logger().error(f"Failed to extract hunk lines from patch: {e}", artifact={"traceback": traceback.format_exc()})
+        get_logger().error(
+            f"Failed to extract hunk lines from patch: {e}",
+            artifact={"traceback": traceback.format_exc()},
+        )
         return "", ""

     return patch_with_lines_str.rstrip(), selected_lines.rstrip()
@@ -9,10 +9,14 @@ def filter_bad_extensions(files):
     bad_extensions = get_settings().bad_extensions.default
     if get_settings().config.use_extra_bad_extensions:
         bad_extensions += get_settings().bad_extensions.extra
-    return [f for f in files if f.filename is not None and is_valid_file(f.filename, bad_extensions)]
+    return [
+        f
+        for f in files
+        if f.filename is not None and is_valid_file(f.filename, bad_extensions)
+    ]


-def is_valid_file(filename:str, bad_extensions=None) -> bool:
+def is_valid_file(filename: str, bad_extensions=None) -> bool:
     if not filename:
         return False
     if not bad_extensions:
@@ -27,12 +31,16 @@ def sort_files_by_main_languages(languages: Dict, files: list):
     Sort files by their main language, put the files that are in the main language first and the rest files after
     """
     # sort languages by their size
-    languages_sorted_list = [k for k, v in sorted(languages.items(), key=lambda item: item[1], reverse=True)]
+    languages_sorted_list = [
+        k for k, v in sorted(languages.items(), key=lambda item: item[1], reverse=True)
+    ]
     # languages_sorted = sorted(languages, key=lambda x: x[1], reverse=True)
     # get all extensions for the languages
     main_extensions = []
     language_extension_map_org = get_settings().language_extension_map_org
-    language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
+    language_extension_map = {
+        k.lower(): v for k, v in language_extension_map_org.items()
+    }
     for language in languages_sorted_list:
         if language.lower() in language_extension_map:
             main_extensions.append(language_extension_map[language.lower()])
@@ -62,7 +70,9 @@ def sort_files_by_main_languages(languages: Dict, files: list):
             if extension_str in extensions:
                 tmp.append(file)
             else:
-                if (file.filename not in rest_files) and (extension_str not in main_extensions_flat):
+                if (file.filename not in rest_files) and (
+                    extension_str not in main_extensions_flat
+                ):
                     rest_files[file.filename] = file
         if len(tmp) > 0:
             files_sorted.append({"language": lang, "files": tmp})
@@ -7,18 +7,28 @@ from github import RateLimitExceededException

from utils.pr_agent.algo.file_filter import filter_ignored
from utils.pr_agent.algo.git_patch_processing import (
convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions)
convert_to_hunks_with_lines_numbers,
extend_patch,
handle_patch_deletions,
)
from utils.pr_agent.algo.language_handler import sort_files_by_main_languages
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.types import EDIT_TYPE
from utils.pr_agent.algo.utils import ModelType, clip_tokens, get_max_tokens, get_weak_model
from utils.pr_agent.algo.utils import (
ModelType,
clip_tokens,
get_max_tokens,
get_weak_model,
)
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers.git_provider import GitProvider
from utils.pr_agent.log import get_logger

DELETED_FILES_ = "Deleted files:\n"

MORE_MODIFIED_FILES_ = "Additional modified files (insufficient token budget to process):\n"
MORE_MODIFIED_FILES_ = (
"Additional modified files (insufficient token budget to process):\n"
)

ADDED_FILES_ = "Additional added files (insufficient token budget to process):\n"

@@ -29,45 +39,59 @@ MAX_EXTRA_LINES = 10

def cap_and_log_extra_lines(value, direction) -> int:
if value > MAX_EXTRA_LINES:
get_logger().warning(f"patch_extra_lines_{direction} was {value}, capping to {MAX_EXTRA_LINES}")
get_logger().warning(
f"patch_extra_lines_{direction} was {value}, capping to {MAX_EXTRA_LINES}"
)
return MAX_EXTRA_LINES
return value


def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
model: str,
add_line_numbers_to_hunks: bool = False,
disable_extra_lines: bool = False,
large_pr_handling=False,
return_remaining_files=False):
def get_pr_diff(
git_provider: GitProvider,
token_handler: TokenHandler,
model: str,
add_line_numbers_to_hunks: bool = False,
disable_extra_lines: bool = False,
large_pr_handling=False,
return_remaining_files=False,
):
if disable_extra_lines:
PATCH_EXTRA_LINES_BEFORE = 0
PATCH_EXTRA_LINES_AFTER = 0
else:
PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(PATCH_EXTRA_LINES_BEFORE, "before")
PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(PATCH_EXTRA_LINES_AFTER, "after")
PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(
PATCH_EXTRA_LINES_BEFORE, "before"
)
PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(
PATCH_EXTRA_LINES_AFTER, "after"
)

try:
diff_files_original = git_provider.get_diff_files()
except RateLimitExceededException as e:
get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
get_logger().error(
f"Rate limit exceeded for git provider API. original message {e}"
)
raise

diff_files = filter_ignored(diff_files_original)
if diff_files != diff_files_original:
try:
get_logger().info(f"Filtered out {len(diff_files_original) - len(diff_files)} files")
get_logger().info(
f"Filtered out {len(diff_files_original) - len(diff_files)} files"
)
new_names = set([a.filename for a in diff_files])
orig_names = set([a.filename for a in diff_files_original])
get_logger().info(f"Filtered out files: {orig_names - new_names}")
except Exception as e:
pass


# get pr languages
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
pr_languages = sort_files_by_main_languages(
git_provider.get_languages(), diff_files
)
if pr_languages:
try:
get_logger().info(f"PR main language: {pr_languages[0]['language']}")
@@ -76,24 +100,42 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,

# generate a standard diff string, with patch extension
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
pr_languages, token_handler, add_line_numbers_to_hunks,
patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE, patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)
pr_languages,
token_handler,
add_line_numbers_to_hunks,
patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE,
patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER,
)

# if we are under the limit, return the full diff
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
get_logger().info(f"Tokens: {total_tokens}, total tokens under limit: {get_max_tokens(model)}, "
f"returning full diff.")
get_logger().info(
f"Tokens: {total_tokens}, total tokens under limit: {get_max_tokens(model)}, "
f"returning full diff."
)
return "\n".join(patches_extended)

# if we are over the limit, start pruning (If we got here, we will not extend the patches with extra lines)
get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
f"pruning diff.")
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling)
get_logger().info(
f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
f"pruning diff."
)
(
patches_compressed_list,
total_tokens_list,
deleted_files_list,
remaining_files_list,
file_dict,
files_in_patches_list,
) = pr_generate_compressed_diff(
pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling
)

if large_pr_handling and len(patches_compressed_list) > 1:
get_logger().info(f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff.")
return "" # return empty string, as we want to generate multiple patches with a different prompt
get_logger().info(
f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff."
)
return "" # return empty string, as we want to generate multiple patches with a different prompt

# return the first patch
patches_compressed = patches_compressed_list[0]
@@ -144,26 +186,37 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
if deleted_list_str:
final_diff = final_diff + "\n\n" + deleted_list_str

get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
f"deleted_list_str: {deleted_list_str}")
get_logger().debug(
f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
f"deleted_list_str: {deleted_list_str}"
)
if not return_remaining_files:
return final_diff
else:
return final_diff, remaining_files_list


def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler, model: str,
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False):
def get_pr_diff_multiple_patchs(
git_provider: GitProvider,
token_handler: TokenHandler,
model: str,
add_line_numbers_to_hunks: bool = False,
disable_extra_lines: bool = False,
):
try:
diff_files_original = git_provider.get_diff_files()
except RateLimitExceededException as e:
get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
get_logger().error(
f"Rate limit exceeded for git provider API. original message {e}"
)
raise

diff_files = filter_ignored(diff_files_original)
if diff_files != diff_files_original:
try:
get_logger().info(f"Filtered out {len(diff_files_original) - len(diff_files)} files")
get_logger().info(
f"Filtered out {len(diff_files_original) - len(diff_files)} files"
)
new_names = set([a.filename for a in diff_files])
orig_names = set([a.filename for a in diff_files_original])
get_logger().info(f"Filtered out files: {orig_names - new_names}")
@@ -171,24 +224,47 @@ def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenH
pass

# get pr languages
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
pr_languages = sort_files_by_main_languages(
git_provider.get_languages(), diff_files
)
if pr_languages:
try:
get_logger().info(f"PR main language: {pr_languages[0]['language']}")
except Exception as e:
pass

patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling=True)
(
patches_compressed_list,
total_tokens_list,
deleted_files_list,
remaining_files_list,
file_dict,
files_in_patches_list,
) = pr_generate_compressed_diff(
pr_languages,
token_handler,
model,
add_line_numbers_to_hunks,
large_pr_handling=True,
)

return patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
return (
patches_compressed_list,
total_tokens_list,
deleted_files_list,
remaining_files_list,
file_dict,
files_in_patches_list,
)


def pr_generate_extended_diff(pr_languages: list,
token_handler: TokenHandler,
add_line_numbers_to_hunks: bool,
patch_extra_lines_before: int = 0,
patch_extra_lines_after: int = 0) -> Tuple[list, int, list]:
def pr_generate_extended_diff(
pr_languages: list,
token_handler: TokenHandler,
add_line_numbers_to_hunks: bool,
patch_extra_lines_before: int = 0,
patch_extra_lines_after: int = 0,
) -> Tuple[list, int, list]:
total_tokens = token_handler.prompt_tokens # initial tokens
patches_extended = []
patches_extended_tokens = []
@@ -200,20 +276,33 @@ def pr_generate_extended_diff(pr_languages: list,
continue

# extend each patch with extra lines of context
extended_patch = extend_patch(original_file_content_str, patch,
patch_extra_lines_before, patch_extra_lines_after, file.filename)
extended_patch = extend_patch(
original_file_content_str,
patch,
patch_extra_lines_before,
patch_extra_lines_after,
file.filename,
)
if not extended_patch:
get_logger().warning(f"Failed to extend patch for file: {file.filename}")
get_logger().warning(
f"Failed to extend patch for file: {file.filename}"
)
continue

if add_line_numbers_to_hunks:
full_extended_patch = convert_to_hunks_with_lines_numbers(extended_patch, file)
full_extended_patch = convert_to_hunks_with_lines_numbers(
extended_patch, file
)
else:
full_extended_patch = f"\n\n## File: '{file.filename.strip()}'\n{extended_patch.rstrip()}\n"

# add AI-summary metadata to the patch
if file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False):
full_extended_patch = add_ai_summary_top_patch(file, full_extended_patch)
if file.ai_file_summary and get_settings().get(
"config.enable_ai_metadata", False
):
full_extended_patch = add_ai_summary_top_patch(
file, full_extended_patch
)

patch_tokens = token_handler.count_tokens(full_extended_patch)
file.tokens = patch_tokens
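Throughout pr_generate_extended_diff and the pruning helpers in the hunks below, the same admission test recurs: prompt tokens plus patch tokens must stay under the model limit minus a soft output buffer. A minimal self-contained sketch of that check (the buffer constant here is illustrative, not the project's actual value):

# Minimal sketch of the token-budget check used in these hunks; the buffer
# value below is an illustrative stand-in, not the project's constant.
OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000

def fits_budget(prompt_tokens: int, patch_tokens: int, max_model_tokens: int) -> bool:
    # Reserve a soft buffer for the model's own output before accepting a patch.
    return prompt_tokens + patch_tokens < max_model_tokens - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD

assert fits_budget(2000, 500, 4000) is True
assert fits_budget(2000, 1500, 4000) is False
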
@@ -224,9 +313,13 @@ def pr_generate_extended_diff(pr_languages: list,
return patches_extended, total_tokens, patches_extended_tokens


def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
convert_hunks_to_line_numbers: bool,
large_pr_handling: bool) -> Tuple[list, list, list, list, dict, list]:
def pr_generate_compressed_diff(
top_langs: list,
token_handler: TokenHandler,
model: str,
convert_hunks_to_line_numbers: bool,
large_pr_handling: bool,
) -> Tuple[list, list, list, list, dict, list]:
deleted_files_list = []

# sort each one of the languages in top_langs by the number of tokens in the diff
@@ -244,8 +337,13 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
continue

# removing delete-only hunks
patch = handle_patch_deletions(patch, original_file_content_str,
new_file_content_str, file.filename, file.edit_type)
patch = handle_patch_deletions(
patch,
original_file_content_str,
new_file_content_str,
file.filename,
file.edit_type,
)
if patch is None:
if file.filename not in deleted_files_list:
deleted_files_list.append(file.filename)
@@ -259,30 +357,54 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
# patch = add_ai_summary_top_patch(file, patch)

new_patch_tokens = token_handler.count_tokens(patch)
file_dict[file.filename] = {'patch': patch, 'tokens': new_patch_tokens, 'edit_type': file.edit_type}
file_dict[file.filename] = {
'patch': patch,
'tokens': new_patch_tokens,
'edit_type': file.edit_type,
}

max_tokens_model = get_max_tokens(model)

# first iteration
files_in_patches_list = []
remaining_files_list = [file.filename for file in sorted_files]
patches_list =[]
remaining_files_list = [file.filename for file in sorted_files]
patches_list = []
total_tokens_list = []
total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers, file_dict,
max_tokens_model, remaining_files_list, token_handler)
(
total_tokens,
patches,
remaining_files_list,
files_in_patch_list,
) = generate_full_patch(
convert_hunks_to_line_numbers,
file_dict,
max_tokens_model,
remaining_files_list,
token_handler,
)
patches_list.append(patches)
total_tokens_list.append(total_tokens)
files_in_patches_list.append(files_in_patch_list)

# additional iterations (if needed)
if large_pr_handling:
NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1 # one more call is to summarize
for i in range(NUMBER_OF_ALLOWED_ITERATIONS-1):
NUMBER_OF_ALLOWED_ITERATIONS = (
get_settings().pr_description.max_ai_calls - 1
) # one more call is to summarize
for i in range(NUMBER_OF_ALLOWED_ITERATIONS - 1):
if remaining_files_list:
total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers,
file_dict,
max_tokens_model,
remaining_files_list, token_handler)
(
total_tokens,
patches,
remaining_files_list,
files_in_patch_list,
) = generate_full_patch(
convert_hunks_to_line_numbers,
file_dict,
max_tokens_model,
remaining_files_list,
token_handler,
)
if patches:
patches_list.append(patches)
total_tokens_list.append(total_tokens)
@@ -290,11 +412,24 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
else:
break

return patches_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
return (
patches_list,
total_tokens_list,
deleted_files_list,
remaining_files_list,
file_dict,
files_in_patches_list,
)


def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_model,remaining_files_list_prev, token_handler):
total_tokens = token_handler.prompt_tokens # initial tokens
def generate_full_patch(
convert_hunks_to_line_numbers,
file_dict,
max_tokens_model,
remaining_files_list_prev,
token_handler,
):
total_tokens = token_handler.prompt_tokens # initial tokens
patches = []
remaining_files_list_new = []
files_in_patch_list = []
@@ -312,7 +447,10 @@ def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_mod
continue

# If the patch is too large, just show the file name
if total_tokens + new_patch_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
if (
total_tokens + new_patch_tokens
> max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
):
# Current logic is to skip the patch if it's too large
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
# until we meet the requirements
@@ -334,7 +472,9 @@ def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_mod
return total_tokens, patches, remaining_files_list_new, files_in_patch_list


async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
async def retry_with_fallback_models(
f: Callable, model_type: ModelType = ModelType.REGULAR
):
all_models = _get_all_models(model_type)
all_deployments = _get_all_deployments(all_models)
# try each (model, deployment_id) pair until one is successful, otherwise raise exception
@@ -347,11 +487,11 @@ async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelT
get_settings().set("openai.deployment_id", deployment_id)
return await f(model)
except:
get_logger().warning(
f"Failed to generate prediction with {model}"
)
get_logger().warning(f"Failed to generate prediction with {model}")
if i == len(all_models) - 1: # If it's the last iteration
raise Exception(f"Failed to generate prediction with any model of {all_models}")
raise Exception(
f"Failed to generate prediction with any model of {all_models}"
)


def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
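The retry_with_fallback_models hunk above is only re-wrapped by this commit; the underlying control flow is: try each candidate model in order, and raise a single summary error once the list is exhausted. A minimal sketch of that pattern (the model names and the flaky callable are illustrative placeholders):

# Minimal sketch of the try-each-model-then-raise pattern shown above;
# the models list and the callable are illustrative, not from this codebase.
import asyncio

async def call_with_fallbacks(f, models: list[str]):
    for i, model in enumerate(models):
        try:
            return await f(model)
        except Exception:
            # Last candidate failed: give up and surface one summary error.
            if i == len(models) - 1:
                raise Exception(f"Failed to generate prediction with any model of {models}")

async def main():
    async def flaky(model):
        if model != "model-b":
            raise RuntimeError("unavailable")
        return f"ok from {model}"
    print(await call_with_fallbacks(flaky, ["model-a", "model-b"]))

asyncio.run(main())
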
@@ -374,17 +514,21 @@ def _get_all_deployments(all_models: List[str]) -> List[str]:
if fallback_deployments:
all_deployments = [deployment_id] + fallback_deployments
if len(all_deployments) < len(all_models):
raise ValueError(f"The number of deployments ({len(all_deployments)}) "
f"is less than the number of models ({len(all_models)})")
raise ValueError(
f"The number of deployments ({len(all_deployments)}) "
f"is less than the number of models ({len(all_models)})"
)
else:
all_deployments = [deployment_id] * len(all_models)
return all_deployments


def get_pr_multi_diffs(git_provider: GitProvider,
token_handler: TokenHandler,
model: str,
max_calls: int = 5) -> List[str]:
def get_pr_multi_diffs(
git_provider: GitProvider,
token_handler: TokenHandler,
model: str,
max_calls: int = 5,
) -> List[str]:
"""
Retrieves the diff files from a Git provider, sorts them by main language, and generates patches for each file.
The patches are split into multiple groups based on the maximum number of tokens allowed for the given model.
@@ -404,13 +548,17 @@ def get_pr_multi_diffs(git_provider: GitProvider,
try:
diff_files = git_provider.get_diff_files()
except RateLimitExceededException as e:
get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
get_logger().error(
f"Rate limit exceeded for git provider API. original message {e}"
)
raise

diff_files = filter_ignored(diff_files)

# Sort files by main language
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
pr_languages = sort_files_by_main_languages(
git_provider.get_languages(), diff_files
)

# Sort files within each language group by tokens in descending order
sorted_files = []
@@ -420,14 +568,19 @@ def get_pr_multi_diffs(git_provider: GitProvider,
# Get the maximum number of extra lines before and after the patch
PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(PATCH_EXTRA_LINES_BEFORE, "before")
PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(
PATCH_EXTRA_LINES_BEFORE, "before"
)
PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(PATCH_EXTRA_LINES_AFTER, "after")

# try first a single run with standard diff string, with patch extension, and no deletions
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
pr_languages, token_handler, add_line_numbers_to_hunks=True,
pr_languages,
token_handler,
add_line_numbers_to_hunks=True,
patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE,
patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)
patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER,
)

# if we are under the limit, return the full diff
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
@@ -450,27 +603,50 @@ def get_pr_multi_diffs(git_provider: GitProvider,
continue

# Remove delete-only hunks
patch = handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file.filename, file.edit_type)
patch = handle_patch_deletions(
patch,
original_file_content_str,
new_file_content_str,
file.filename,
file.edit_type,
)
if patch is None:
continue

patch = convert_to_hunks_with_lines_numbers(patch, file)
# add AI-summary metadata to the patch
if file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False):
if file.ai_file_summary and get_settings().get(
"config.enable_ai_metadata", False
):
patch = add_ai_summary_top_patch(file, patch)
new_patch_tokens = token_handler.count_tokens(patch)

if patch and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
if (
patch
and (token_handler.prompt_tokens + new_patch_tokens)
> get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
):
if get_settings().config.get('large_patch_policy', 'skip') == 'skip':
get_logger().warning(f"Patch too large, skipping: {file.filename}")
continue
elif get_settings().config.get('large_patch_policy') == 'clip':
delta_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens
patch_clipped = clip_tokens(patch, delta_tokens, delete_last_line=True, num_input_tokens=new_patch_tokens)
delta_tokens = (
get_max_tokens(model)
- OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
- token_handler.prompt_tokens
)
patch_clipped = clip_tokens(
patch,
delta_tokens,
delete_last_line=True,
num_input_tokens=new_patch_tokens,
)
new_patch_tokens = token_handler.count_tokens(patch_clipped)
if patch_clipped and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
if (
patch_clipped
and (token_handler.prompt_tokens + new_patch_tokens)
> get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
):
get_logger().warning(f"Patch too large, skipping: {file.filename}")
continue
else:
@@ -480,13 +656,16 @@ def get_pr_multi_diffs(git_provider: GitProvider,
get_logger().warning(f"Patch too large, skipping: {file.filename}")
continue

if patch and (total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
if patch and (
total_tokens + new_patch_tokens
> get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
):
final_diff = "\n".join(patches)
final_diff_list.append(final_diff)
patches = []
total_tokens = token_handler.prompt_tokens
call_number += 1
if call_number > max_calls: # avoid creating new patches
if call_number > max_calls: # avoid creating new patches
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Reached max calls ({max_calls})")
break
@@ -497,7 +676,9 @@ def get_pr_multi_diffs(git_provider: GitProvider,
patches.append(patch)
total_tokens += new_patch_tokens
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Tokens: {total_tokens}, last filename: {file.filename}")
get_logger().info(
f"Tokens: {total_tokens}, last filename: {file.filename}"
)

# Add the last chunk
if patches:
@@ -515,7 +696,10 @@ def add_ai_metadata_to_diff_files(git_provider, pr_description_files):
if not pr_description_files:
get_logger().warning(f"PR description files are empty.")
return
available_files = {pr_file['full_file_name'].strip(): pr_file for pr_file in pr_description_files}
available_files = {
pr_file['full_file_name'].strip(): pr_file
for pr_file in pr_description_files
}
diff_files = git_provider.get_diff_files()
found_any_match = False
for file in diff_files:
@@ -524,11 +708,15 @@ def add_ai_metadata_to_diff_files(git_provider, pr_description_files):
file.ai_file_summary = available_files[filename]
found_any_match = True
if not found_any_match:
get_logger().error(f"Failed to find any matching files between PR description and diff files.",
artifact={"pr_description_files": pr_description_files})
get_logger().error(
f"Failed to find any matching files between PR description and diff files.",
artifact={"pr_description_files": pr_description_files},
)
except Exception as e:
get_logger().error(f"Failed to add AI metadata to diff files: {e}",
artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Failed to add AI metadata to diff files: {e}",
artifact={"traceback": traceback.format_exc()},
)


def add_ai_summary_top_patch(file, full_extended_patch):
@@ -537,14 +725,18 @@ def add_ai_summary_top_patch(file, full_extended_patch):
full_extended_patch_lines = full_extended_patch.split("\n")
for i, line in enumerate(full_extended_patch_lines):
if line.startswith("## File:") or line.startswith("## file:"):
full_extended_patch_lines.insert(i + 1,
f"### AI-generated changes summary:\n{file.ai_file_summary['long_summary']}")
full_extended_patch_lines.insert(
i + 1,
f"### AI-generated changes summary:\n{file.ai_file_summary['long_summary']}",
)
full_extended_patch = "\n".join(full_extended_patch_lines)
return full_extended_patch

# if no '## File: ...' was found
return full_extended_patch
except Exception as e:
get_logger().error(f"Failed to add AI summary to the top of the patch: {e}",
artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Failed to add AI summary to the top of the patch: {e}",
artifact={"traceback": traceback.format_exc()},
)
return full_extended_patch

@@ -15,12 +15,17 @@ class TokenEncoder:
@classmethod
def get_token_encoder(cls):
model = get_settings().config.model
if cls._encoder_instance is None or model != cls._model: # Check without acquiring the lock for performance
if (
cls._encoder_instance is None or model != cls._model
): # Check without acquiring the lock for performance
with cls._lock: # Lock acquisition to ensure thread safety
if cls._encoder_instance is None or model != cls._model:
cls._model = model
cls._encoder_instance = encoding_for_model(cls._model) if "gpt" in cls._model else get_encoding(
"cl100k_base")
cls._encoder_instance = (
encoding_for_model(cls._model)
if "gpt" in cls._model
else get_encoding("cl100k_base")
)
return cls._encoder_instance


@@ -49,7 +54,9 @@ class TokenHandler:
"""
self.encoder = TokenEncoder.get_token_encoder()
if pr is not None:
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
self.prompt_tokens = self._get_system_user_tokens(
pr, self.encoder, vars, system, user
)

def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
"""

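The TokenEncoder.get_token_encoder hunk above is a textbook double-checked locking cache: a cheap unlocked check, then the lock, then a re-check before the expensive construction. A minimal sketch of the same pattern in isolation (the cached object here is a stand-in for the tiktoken encoder):

# Minimal sketch of the double-checked locking pattern used by
# TokenEncoder.get_token_encoder(); object() stands in for an expensive build.
import threading

class Cached:
    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get(cls):
        if cls._instance is None:  # cheap check without the lock
            with cls._lock:  # serialize the slow path
                if cls._instance is None:  # re-check: another thread may have won
                    cls._instance = object()  # stand-in for an expensive build
        return cls._instance

assert Cached.get() is Cached.get()
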
@@ -41,10 +41,12 @@ class Range(BaseModel):
column_start: int = -1
column_end: int = -1


class ModelType(str, Enum):
REGULAR = "regular"
WEAK = "weak"


class PRReviewHeader(str, Enum):
REGULAR = "## PR 评审指南"
INCREMENTAL = "## 增量 PR 评审指南"
@@ -57,7 +59,9 @@ class PRDescriptionHeader(str, Enum):
def get_setting(key: str) -> Any:
try:
key = key.upper()
return context.get("settings", global_settings).get(key, global_settings.get(key, None))
return context.get("settings", global_settings).get(
key, global_settings.get(key, None)
)
except Exception:
return global_settings.get(key, None)

@@ -72,14 +76,29 @@ def emphasize_header(text: str, only_markdown=False, reference_link=None) -> str
# Everything before the colon (inclusive) is wrapped in <strong> tags
if only_markdown:
if reference_link:
transformed_string = f"[**{text[:colon_position + 1]}**]({reference_link})\n" + text[colon_position + 1:]
transformed_string = (
f"[**{text[:colon_position + 1]}**]({reference_link})\n"
+ text[colon_position + 1 :]
)
else:
transformed_string = f"**{text[:colon_position + 1]}**\n" + text[colon_position + 1:]
transformed_string = (
f"**{text[:colon_position + 1]}**\n"
+ text[colon_position + 1 :]
)
else:
if reference_link:
transformed_string = f"<strong><a href='{reference_link}'>{text[:colon_position + 1]}</a></strong><br>" + text[colon_position + 1:]
transformed_string = (
f"<strong><a href='{reference_link}'>{text[:colon_position + 1]}</a></strong><br>"
+ text[colon_position + 1 :]
)
else:
transformed_string = "<strong>" + text[:colon_position + 1] + "</strong>" +'<br>' + text[colon_position + 1:]
transformed_string = (
"<strong>"
+ text[: colon_position + 1]
+ "</strong>"
+ '<br>'
+ text[colon_position + 1 :]
)
else:
# If there's no ": ", return the original string
transformed_string = text
@@ -101,11 +120,14 @@ def unique_strings(input_list: List[str]) -> List[str]:
seen.add(item)
return unique_list

def convert_to_markdown_v2(output_data: dict,
gfm_supported: bool = True,
incremental_review=None,
git_provider=None,
files=None) -> str:

def convert_to_markdown_v2(
output_data: dict,
gfm_supported: bool = True,
incremental_review=None,
git_provider=None,
files=None,
) -> str:
"""
Convert a dictionary of data into markdown format.
Args:
@@ -183,7 +205,9 @@ def convert_to_markdown_v2(output_data: dict,
else:
markdown_text += f"### {emoji} PR 包含测试\n\n"
elif 'ticket compliance check' in key_nice.lower():
markdown_text = ticket_markdown_logic(emoji, markdown_text, value, gfm_supported)
markdown_text = ticket_markdown_logic(
emoji, markdown_text, value, gfm_supported
)
elif 'security concerns' in key_nice.lower():
if gfm_supported:
markdown_text += f"<tr><td>"
@@ -220,7 +244,9 @@ def convert_to_markdown_v2(output_data: dict,
if gfm_supported:
markdown_text += f"<tr><td>"
# markdown_text += f"{emoji} <strong>{key_nice}</strong><br><br>\n\n"
markdown_text += f"{emoji} <strong>建议评审的重点领域</strong><br><br>\n\n"
markdown_text += (
f"{emoji} <strong>建议评审的重点领域</strong><br><br>\n\n"
)
else:
markdown_text += f"### {emoji} 建议评审的重点领域\n\n#### \n"
for i, issue in enumerate(issues):
@@ -235,9 +261,13 @@ def convert_to_markdown_v2(output_data: dict,
start_line = int(str(issue.get('start_line', 0)).strip())
end_line = int(str(issue.get('end_line', 0)).strip())

relevant_lines_str = extract_relevant_lines_str(end_line, files, relevant_file, start_line, dedent=True)
relevant_lines_str = extract_relevant_lines_str(
end_line, files, relevant_file, start_line, dedent=True
)
if git_provider:
reference_link = git_provider.get_line_link(relevant_file, start_line, end_line)
reference_link = git_provider.get_line_link(
relevant_file, start_line, end_line
)
else:
reference_link = None

@@ -256,7 +286,9 @@ def convert_to_markdown_v2(output_data: dict,
issue_str = f"**{issue_header}**\n\n{issue_content}\n\n"
markdown_text += f"{issue_str}\n\n"
except Exception as e:
get_logger().exception(f"Failed to process 'Recommended focus areas for review': {e}")
get_logger().exception(
f"Failed to process 'Recommended focus areas for review': {e}"
)
if gfm_supported:
markdown_text += f"</td></tr>\n"
else:
@@ -273,7 +305,9 @@ def convert_to_markdown_v2(output_data: dict,
return markdown_text


def extract_relevant_lines_str(end_line, files, relevant_file, start_line, dedent=False) -> str:
def extract_relevant_lines_str(
end_line, files, relevant_file, start_line, dedent=False
) -> str:
"""
Finds 'relevant_file' in 'files', and extracts the lines from 'start_line' to 'end_line' string from the file content.
"""
@@ -286,10 +320,16 @@ def extract_relevant_lines_str(end_line, files, relevant_file, start_line, deden
if not file.head_file:
# as a fallback, extract relevant lines directly from patch
patch = file.patch
get_logger().info(f"No content found in file: '{file.filename}' for 'extract_relevant_lines_str'. Using patch instead")
_, selected_lines = extract_hunk_lines_from_patch(patch, file.filename, start_line, end_line,side='right')
get_logger().info(
f"No content found in file: '{file.filename}' for 'extract_relevant_lines_str'. Using patch instead"
)
_, selected_lines = extract_hunk_lines_from_patch(
patch, file.filename, start_line, end_line, side='right'
)
if not selected_lines:
get_logger().error(f"Failed to extract relevant lines from patch: {file.filename}")
get_logger().error(
f"Failed to extract relevant lines from patch: {file.filename}"
)
return ""
# filter out '-' lines
relevant_lines_str = ""
@@ -299,12 +339,16 @@ def extract_relevant_lines_str(end_line, files, relevant_file, start_line, deden
relevant_lines_str += line[1:] + '\n'
else:
relevant_file_lines = file.head_file.splitlines()
relevant_lines_str = "\n".join(relevant_file_lines[start_line - 1:end_line])
relevant_lines_str = "\n".join(
relevant_file_lines[start_line - 1 : end_line]
)

if dedent and relevant_lines_str:
# Remove the longest leading string of spaces and tabs common to all lines.
relevant_lines_str = textwrap.dedent(relevant_lines_str)
relevant_lines_str = f"```{file.language}\n{relevant_lines_str}\n```"
relevant_lines_str = (
f"```{file.language}\n{relevant_lines_str}\n```"
)
break

return relevant_lines_str
@@ -325,14 +369,21 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
ticket_url = ticket_analysis.get('ticket_url', '').strip()
explanation = ''
ticket_compliance_level = '' # Individual ticket compliance
fully_compliant_str = ticket_analysis.get('fully_compliant_requirements', '').strip()
not_compliant_str = ticket_analysis.get('not_compliant_requirements', '').strip()
requires_further_human_verification = ticket_analysis.get('requires_further_human_verification',
'').strip()
fully_compliant_str = ticket_analysis.get(
'fully_compliant_requirements', ''
).strip()
not_compliant_str = ticket_analysis.get(
'not_compliant_requirements', ''
).strip()
requires_further_human_verification = ticket_analysis.get(
'requires_further_human_verification', ''
).strip()

if not fully_compliant_str and not not_compliant_str:
get_logger().debug(f"Ticket compliance has no requirements",
artifact={'ticket_url': ticket_url})
get_logger().debug(
f"Ticket compliance has no requirements",
artifact={'ticket_url': ticket_url},
)
continue

# Calculate individual ticket compliance level
@@ -353,19 +404,27 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:

# build compliance string
if fully_compliant_str:
explanation += f"Compliant requirements:\n\n{fully_compliant_str}\n\n"
explanation += (
f"Compliant requirements:\n\n{fully_compliant_str}\n\n"
)
if not_compliant_str:
explanation += f"Non-compliant requirements:\n\n{not_compliant_str}\n\n"
explanation += (
f"Non-compliant requirements:\n\n{not_compliant_str}\n\n"
)
if requires_further_human_verification:
explanation += f"Requires further human verification:\n\n{requires_further_human_verification}\n\n"
ticket_compliance_str += f"\n\n**[{ticket_url.split('/')[-1]}]({ticket_url}) - {ticket_compliance_level}**\n\n{explanation}\n\n"

# for debugging
if requires_further_human_verification:
get_logger().debug(f"Ticket compliance requires further human verification",
artifact={'ticket_url': ticket_url,
'requires_further_human_verification': requires_further_human_verification,
'compliance_level': ticket_compliance_level})
get_logger().debug(
f"Ticket compliance requires further human verification",
artifact={
'ticket_url': ticket_url,
'requires_further_human_verification': requires_further_human_verification,
'compliance_level': ticket_compliance_level,
},
)

except Exception as e:
get_logger().exception(f"Failed to process ticket compliance: {e}")
@@ -381,7 +440,10 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
compliance_emoji = '✅'
elif any(level == 'Not compliant' for level in all_compliance_levels):
# If there's a mix of compliant and non-compliant tickets
if any(level in ['Fully compliant', 'PR Code Verified'] for level in all_compliance_levels):
if any(
level in ['Fully compliant', 'PR Code Verified']
for level in all_compliance_levels
):
compliance_level = 'Partially compliant'
compliance_emoji = '🔶'
else:
@@ -395,7 +457,9 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
compliance_emoji = '✅'

# Set extra statistics outside the ticket loop
get_settings().set('config.extra_statistics', {'compliance_level': compliance_level})
get_settings().set(
'config.extra_statistics', {'compliance_level': compliance_level}
)

# editing table row for ticket compliance analysis
if gfm_supported:
@@ -425,7 +489,9 @@ def process_can_be_split(emoji, value):
for i, split in enumerate(value):
title = split.get('title', '')
relevant_files = split.get('relevant_files', [])
markdown_text += f"<details><summary>\n子 PR 主题: <b>{title}</b></summary>\n\n"
markdown_text += (
f"<details><summary>\n子 PR 主题: <b>{title}</b></summary>\n\n"
)
markdown_text += f"___\n\n相关文件:\n\n"
for file in relevant_files:
markdown_text += f"- {file}\n"
@@ -464,7 +530,9 @@ def process_can_be_split(emoji, value):
return markdown_text


def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool = True) -> str:
def parse_code_suggestion(
code_suggestion: dict, i: int = 0, gfm_supported: bool = True
) -> str:
"""
Convert a dictionary of data into markdown format.

@@ -484,15 +552,19 @@ def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool
markdown_text += f"<tr><td>相关文件</td><td>{relevant_file}</td></tr>"
# continue
elif sub_key.lower() == 'suggestion':
markdown_text += (f"<tr><td>{sub_key} </td>"
f"<td>\n\n<strong>\n\n{sub_value.strip()}\n\n</strong>\n</td></tr>")
markdown_text += (
f"<tr><td>{sub_key} </td>"
f"<td>\n\n<strong>\n\n{sub_value.strip()}\n\n</strong>\n</td></tr>"
)
elif sub_key.lower() == 'relevant_line':
markdown_text += f"<tr><td>相关行</td>"
sub_value_list = sub_value.split('](')
relevant_line = sub_value_list[0].lstrip('`').lstrip('[')
if len(sub_value_list) > 1:
link = sub_value_list[1].rstrip(')').strip('`')
markdown_text += f"<td><a href='{link}'>{relevant_line}</a></td>"
markdown_text += (
f"<td><a href='{link}'>{relevant_line}</a></td>"
)
else:
markdown_text += f"<td>{relevant_line}</td>"
markdown_text += "</tr>"
@@ -505,11 +577,14 @@ def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool
for sub_key, sub_value in code_suggestion.items():
if isinstance(sub_key, str):
sub_key = sub_key.rstrip()
if isinstance(sub_value,str):
if isinstance(sub_value, str):
sub_value = sub_value.rstrip()
if isinstance(sub_value, dict): # "code example"
markdown_text += f" - **{sub_key}:**\n"
for code_key, code_value in sub_value.items(): # 'before' and 'after' code
for (
code_key,
code_value,
) in sub_value.items(): # 'before' and 'after' code
code_str = f"```\n{code_value}\n```"
code_str_indented = textwrap.indent(code_str, ' ')
markdown_text += f" - **{code_key}:**\n{code_str_indented}\n"
@@ -520,7 +595,9 @@ def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool
markdown_text += f" **{sub_key}:** {sub_value} \n"
if "relevant_line" not in sub_key.lower(): # nicer presentation
# markdown_text = markdown_text.rstrip('\n') + "\\\n" # works for gitlab
markdown_text = markdown_text.rstrip('\n') + " \n" # works for gitlab and bitbucker
markdown_text = (
markdown_text.rstrip('\n') + " \n"
) # works for gitlab and bitbucker

markdown_text += "\n"
return markdown_text
@@ -561,9 +638,15 @@ def try_fix_json(review, max_iter=10, code_suggestions=False):
else:
closing_bracket = "]}}"

if (review.rfind("'Code feedback': [") > 0 or review.rfind('"Code feedback": [') > 0) or \
(review.rfind("'Code suggestions': [") > 0 or review.rfind('"Code suggestions": [') > 0) :
last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1
if (
review.rfind("'Code feedback': [") > 0 or review.rfind('"Code feedback": [') > 0
) or (
review.rfind("'Code suggestions': [") > 0
or review.rfind('"Code suggestions": [') > 0
):
last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][
-1
] - 1
valid_json = False
iter_count = 0

@@ -574,7 +657,9 @@ def try_fix_json(review, max_iter=10, code_suggestions=False):
review = review[:last_code_suggestion_ind].strip() + closing_bracket
except json.decoder.JSONDecodeError:
review = review[:last_code_suggestion_ind]
last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1
last_code_suggestion_ind = [
m.end() for m in re.finditer(r"\}\s*,", review)
][-1] - 1
iter_count += 1

if not valid_json:
@@ -629,7 +714,12 @@ def convert_str_to_datetime(date_str):
return datetime.strptime(date_str, datetime_format)


def load_large_diff(filename, new_file_content_str: str, original_file_content_str: str, show_warning: bool = True) -> str:
def load_large_diff(
filename,
new_file_content_str: str,
original_file_content_str: str,
show_warning: bool = True,
) -> str:
"""
Generate a patch for a modified file by comparing the original content of the file with the new content provided as
input.
@@ -640,10 +730,14 @@ def load_large_diff(filename, new_file_content_str: str, original_file_content_s
try:
original_file_content_str = (original_file_content_str or "").rstrip() + "\n"
new_file_content_str = (new_file_content_str or "").rstrip() + "\n"
diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
new_file_content_str.splitlines(keepends=True))
diff = difflib.unified_diff(
original_file_content_str.splitlines(keepends=True),
new_file_content_str.splitlines(keepends=True),
)
if get_settings().config.verbosity_level >= 2 and show_warning:
get_logger().info(f"File was modified, but no patch was found. Manually creating patch: {filename}.")
get_logger().info(
f"File was modified, but no patch was found. Manually creating patch: {filename}."
)
patch = ''.join(diff)
return patch
except Exception as e:
@@ -693,42 +787,68 @@ def _fix_key_value(key: str, value: str):
try:
value = yaml.safe_load(value)
except Exception as e:
get_logger().debug(f"Failed to parse YAML for config override {key}={value}", exc_info=e)
get_logger().debug(
f"Failed to parse YAML for config override {key}={value}", exc_info=e
)
return key, value


def load_yaml(response_text: str, keys_fix_yaml: List[str] = [], first_key="", last_key="") -> dict:
response_text = response_text.strip('\n').removeprefix('```yaml').rstrip().removesuffix('```')
def load_yaml(
response_text: str, keys_fix_yaml: List[str] = [], first_key="", last_key=""
) -> dict:
response_text = (
response_text.strip('\n').removeprefix('```yaml').rstrip().removesuffix('```')
)
try:
data = yaml.safe_load(response_text)
except Exception as e:
get_logger().warning(f"Initial failure to parse AI prediction: {e}")
data = try_fix_yaml(response_text, keys_fix_yaml=keys_fix_yaml, first_key=first_key, last_key=last_key)
data = try_fix_yaml(
response_text,
keys_fix_yaml=keys_fix_yaml,
first_key=first_key,
last_key=last_key,
)
if not data:
get_logger().error(f"Failed to parse AI prediction after fallbacks",
artifact={'response_text': response_text})
get_logger().error(
f"Failed to parse AI prediction after fallbacks",
artifact={'response_text': response_text},
)
else:
get_logger().info(f"Successfully parsed AI prediction after fallbacks",
artifact={'response_text': response_text})
get_logger().info(
f"Successfully parsed AI prediction after fallbacks",
artifact={'response_text': response_text},
)
return data



def try_fix_yaml(response_text: str,
keys_fix_yaml: List[str] = [],
first_key="",
last_key="",) -> dict:
def try_fix_yaml(
response_text: str,
keys_fix_yaml: List[str] = [],
first_key="",
last_key="",
) -> dict:
response_text_lines = response_text.split('\n')

keys_yaml = ['relevant line:', 'suggestion content:', 'relevant file:', 'existing code:', 'improved code:']
keys_yaml = [
'relevant line:',
'suggestion content:',
'relevant file:',
'existing code:',
'improved code:',
]
keys_yaml = keys_yaml + keys_fix_yaml
# first fallback - try to convert 'relevant line: ...' to relevant line: |-\n ...'
response_text_lines_copy = response_text_lines.copy()
for i in range(0, len(response_text_lines_copy)):
for key in keys_yaml:
if key in response_text_lines_copy[i] and not '|' in response_text_lines_copy[i]:
response_text_lines_copy[i] = response_text_lines_copy[i].replace(f'{key}',
f'{key} |\n ')
if (
key in response_text_lines_copy[i]
and not '|' in response_text_lines_copy[i]
):
response_text_lines_copy[i] = response_text_lines_copy[i].replace(
f'{key}', f'{key} |\n '
)
try:
data = yaml.safe_load('\n'.join(response_text_lines_copy))
get_logger().info(f"Successfully parsed AI prediction after adding |-\n")
@@ -743,22 +863,26 @@ def try_fix_yaml(response_text: str,
snippet_text = snippet.group()
try:
data = yaml.safe_load(snippet_text.removeprefix('```yaml').rstrip('`'))
get_logger().info(f"Successfully parsed AI prediction after extracting yaml snippet")
get_logger().info(
f"Successfully parsed AI prediction after extracting yaml snippet"
)
return data
except:
pass


# third fallback - try to remove leading and trailing curly brackets
response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n')
response_text_copy = (
response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n')
)
try:
data = yaml.safe_load(response_text_copy)
get_logger().info(f"Successfully parsed AI prediction after removing curly brackets")
get_logger().info(
f"Successfully parsed AI prediction after removing curly brackets"
)
return data
except:
pass


# forth fallback - try to extract yaml snippet by 'first_key' and 'last_key'
# note that 'last_key' can be in practice a key that is not the last key in the yaml snippet.
# it just needs to be some inner key, so we can look for newlines after it
@@ -767,13 +891,23 @@ def try_fix_yaml(response_text: str,
if index_start == -1:
index_start = response_text.find(f"{first_key}:")
index_last_code = response_text.rfind(f"{last_key}:")
index_end = response_text.find("\n\n", index_last_code) # look for newlines after last_key
index_end = response_text.find(
"\n\n", index_last_code
) # look for newlines after last_key
if index_end == -1:
index_end = len(response_text)
response_text_copy = response_text[index_start:index_end].strip().strip('```yaml').strip('`').strip()
response_text_copy = (
response_text[index_start:index_end]
.strip()
.strip('```yaml')
.strip('`')
.strip()
)
try:
data = yaml.safe_load(response_text_copy)
get_logger().info(f"Successfully parsed AI prediction after extracting yaml snippet")
get_logger().info(
f"Successfully parsed AI prediction after extracting yaml snippet"
)
return data
except:
pass
@@ -784,7 +918,9 @@ def try_fix_yaml(response_text: str,
response_text_lines_copy[i] = ' ' + response_text_lines_copy[i][1:]
try:
data = yaml.safe_load('\n'.join(response_text_lines_copy))
get_logger().info(f"Successfully parsed AI prediction after removing leading '+'")
get_logger().info(
f"Successfully parsed AI prediction after removing leading '+'"
)
return data
except:
pass
@@ -794,7 +930,9 @@ def try_fix_yaml(response_text: str,
response_text_lines_tmp = '\n'.join(response_text_lines[:-i])
try:
data = yaml.safe_load(response_text_lines_tmp)
get_logger().info(f"Successfully parsed AI prediction after removing {i} lines")
get_logger().info(
f"Successfully parsed AI prediction after removing {i} lines"
)
return data
except:
pass
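The try_fix_yaml hunks walk a fixed ladder of repairs, re-attempting yaml.safe_load after each one and returning on the first success. A minimal sketch of that parse-then-fallback cascade (the repair steps below are simplified stand-ins for the fallbacks above):

# Minimal sketch of the parse-then-fallback ladder in try_fix_yaml;
# the two repair steps below are simplified stand-ins for the real fallbacks.
import yaml

def load_with_fallbacks(text: str):
    candidates = [
        text,
        text.removeprefix("```yaml").removesuffix("```"),  # strip code fences
        text.strip().removeprefix("{").removesuffix("}"),  # strip stray braces
    ]
    for candidate in candidates:
        try:
            data = yaml.safe_load(candidate)
            if data is not None:
                return data
        except yaml.YAMLError:
            continue
    return None

print(load_with_fallbacks("```yaml\nkey: value\n```"))  # {'key': 'value'}
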
@@ -820,11 +958,14 @@ def set_custom_labels(variables, git_provider=None):
for k, v in labels.items():
description = "'" + v['description'].strip('\n').replace('\n', '\\n') + "'"
# variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}"
variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = {description}"
variables[
"custom_labels_class"
] += f"\n {k.lower().replace(' ', '_')} = {description}"
labels_minimal_to_labels_dict[k.lower().replace(' ', '_')] = k
counter += 1
variables["labels_minimal_to_labels_dict"] = labels_minimal_to_labels_dict


def get_user_labels(current_labels: List[str] = None):
"""
Only keep labels that has been added by the user
@@ -866,14 +1007,22 @@ def get_max_tokens(model):
elif settings.config.custom_model_max_tokens > 0:
max_tokens_model = settings.config.custom_model_max_tokens
else:
raise Exception(f"Ensure {model} is defined in MAX_TOKENS in ./pr_agent/algo/__init__.py or set a positive value for it in config.custom_model_max_tokens")
raise Exception(
f"Ensure {model} is defined in MAX_TOKENS in ./pr_agent/algo/__init__.py or set a positive value for it in config.custom_model_max_tokens"
)

if settings.config.max_model_tokens and settings.config.max_model_tokens > 0:
max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
return max_tokens_model


def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_tokens=None, delete_last_line=False) -> str:
def clip_tokens(
text: str,
max_tokens: int,
add_three_dots=True,
num_input_tokens=None,
delete_last_line=False,
) -> str:
"""
Clip the number of tokens in a string to a maximum number of tokens.

@ -909,14 +1058,15 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_token
|
||||
clipped_text = clipped_text.rsplit('\n', 1)[0]
|
||||
if add_three_dots:
|
||||
clipped_text += "\n...(truncated)"
|
||||
else: # if the text is empty
|
||||
clipped_text = ""
|
||||
else: # if the text is empty
|
||||
clipped_text = ""
|
||||
|
||||
return clipped_text
|
||||
except Exception as e:
|
||||
get_logger().warning(f"Failed to clip tokens: {e}")
|
||||
return text
|
||||
|
||||
|
||||
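The expanded clip_tokens signature changes nothing semantically. For intuition, the clipping itself is proportional: the character budget is scaled by max_tokens over the measured token count, and the dangling partial line is trimmed before the marker is appended. A rough sketch of that arithmetic, assuming the token count has already been computed elsewhere:

def clip_tokens_sketch(text: str, max_tokens: int, num_input_tokens: int) -> str:
    # Proportional character clipping: assume tokens are spread evenly over the text.
    if num_input_tokens <= max_tokens:
        return text
    chars_per_token = len(text) / num_input_tokens
    clipped_text = text[: int(chars_per_token * max_tokens)]
    clipped_text = clipped_text.rsplit('\n', 1)[0]  # drop the trailing partial line
    return clipped_text + "\n...(truncated)"

print(clip_tokens_sketch("a b c\nd e f\ng h i", max_tokens=2, num_input_tokens=9))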
def replace_code_tags(text):
    """
    Replace odd instances of ` with <code> and even instances of ` with </code>
@@ -928,15 +1078,16 @@ def replace_code_tags(text):
    return ''.join(parts)

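The docstring and the ''.join(parts) return above imply the usual split-on-backtick trick: every odd-indexed part sits between a pair of backticks. A sketch consistent with both, though the body shown here is a reconstruction rather than the diffed source:

def replace_code_tags_sketch(text):
    # Odd-indexed parts lie between backtick pairs, so they get wrapped.
    parts = text.split('`')
    for i in range(1, len(parts), 2):
        parts[i] = '<code>' + parts[i] + '</code>'
    return ''.join(parts)

print(replace_code_tags_sketch("use `foo()` here"))  # use <code>foo()</code> here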
def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
                                              relevant_file: str,
                                              relevant_line_in_file: str,
                                              absolute_position: int = None) -> Tuple[int, int]:
def find_line_number_of_relevant_line_in_file(
    diff_files: List[FilePatchInfo],
    relevant_file: str,
    relevant_line_in_file: str,
    absolute_position: int = None,
) -> Tuple[int, int]:
    position = -1
    if absolute_position is None:
        absolute_position = -1
    re_hunk_header = re.compile(
        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
    re_hunk_header = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

    if not diff_files:
        return position, absolute_position
@@ -947,7 +1098,7 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
        patch_lines = patch.splitlines()
        delta = 0
        start1, size1, start2, size2 = 0, 0, 0, 0
        if absolute_position != -1: # matching absolute to relative
        if absolute_position != -1:  # matching absolute to relative
            for i, line in enumerate(patch_lines):
                # new hunk
                if line.startswith('@@'):
@@ -965,12 +1116,12 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
                    break
        else:
            # try to find the line in the patch using difflib, with some margin of error
            matches_difflib: list[str | Any] = difflib.get_close_matches(relevant_line_in_file,
                                                                         patch_lines, n=3, cutoff=0.93)
            matches_difflib: list[str | Any] = difflib.get_close_matches(
                relevant_line_in_file, patch_lines, n=3, cutoff=0.93
            )
            if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'):
                relevant_line_in_file = matches_difflib[0]


            for i, line in enumerate(patch_lines):
                if line.startswith('@@'):
                    delta = 0
@@ -1002,19 +1153,26 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
                    break
    return position, absolute_position

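The re_hunk_header pattern collapsed onto one line above is the standard unified-diff hunk-header parser; its four numeric groups feed the start1/size1/start2/size2 bookkeeping in the loop. A quick demonstration:

import re

re_hunk_header = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
m = re_hunk_header.match("@@ -947,7 +1098,7 @@ def find_line_number_of_relevant_line_in_file(...):")
# The captures map to start1, size1, start2, size2 in the loop above.
print(m.group(1), m.group(2), m.group(3), m.group(4))  # 947 7 1098 7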
def get_rate_limit_status(github_token) -> dict:
    GITHUB_API_URL = get_settings(use_context=False).get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/") # "https://api.github.com"
    GITHUB_API_URL = (
        get_settings(use_context=False)
        .get("GITHUB.BASE_URL", "https://api.github.com")
        .rstrip("/")
    )  # "https://api.github.com"
    # GITHUB_API_URL = "https://api.github.com"
    RATE_LIMIT_URL = f"{GITHUB_API_URL}/rate_limit"
    HEADERS = {
        "Accept": "application/vnd.github.v3+json",
        "Authorization": f"token {github_token}"
        "Authorization": f"token {github_token}",
    }

    response = requests.get(RATE_LIMIT_URL, headers=HEADERS)
    try:
        rate_limit_info = response.json()
        if rate_limit_info.get('message') == 'Rate limiting is not enabled.': # for github enterprise
        if (
            rate_limit_info.get('message') == 'Rate limiting is not enabled.'
        ): # for github enterprise
            return {'resources': {}}
        response.raise_for_status() # Check for HTTP errors
    except: # retry
@@ -1024,12 +1182,16 @@ def get_rate_limit_status(github_token) -> dict:
    return rate_limit_info


def validate_rate_limit_github(github_token, installation_id=None, threshold=0.1) -> bool:
def validate_rate_limit_github(
    github_token, installation_id=None, threshold=0.1
) -> bool:
    try:
        rate_limit_status = get_rate_limit_status(github_token)
        if installation_id:
            get_logger().debug(f"installation_id: {installation_id}, Rate limit status: {rate_limit_status['rate']}")
        # validate that the rate limit is not exceeded
            get_logger().debug(
                f"installation_id: {installation_id}, Rate limit status: {rate_limit_status['rate']}"
            )
        # validate that the rate limit is not exceeded
        # validate that the rate limit is not exceeded
        for key, value in rate_limit_status['resources'].items():
            if value['remaining'] < value['limit'] * threshold:
@@ -1037,8 +1199,9 @@ def validate_rate_limit_github(github_token, installation_id=None, threshold=0.1
                return False
        return True
    except Exception as e:
        get_logger().error(f"Error in rate limit {e}",
                           artifact={"traceback": traceback.format_exc()})
        get_logger().error(
            f"Error in rate limit {e}", artifact={"traceback": traceback.format_exc()}
        )
        return True

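validate_rate_limit_github walks every resource bucket returned by GitHub's /rate_limit endpoint and reports failure once less than threshold of any quota remains. The core check, isolated over a canned payload (helper name hypothetical):

def has_headroom(rate_limit_status: dict, threshold: float = 0.1) -> bool:
    # Fail as soon as any bucket drops below the threshold fraction of its limit.
    for key, value in rate_limit_status['resources'].items():
        if value['remaining'] < value['limit'] * threshold:
            return False
    return True

status = {'resources': {'core': {'remaining': 400, 'limit': 5000}}}
print(has_headroom(status))  # False: 400 < 5000 * 0.1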
@@ -1051,7 +1214,9 @@ def validate_and_await_rate_limit(github_token):
            get_logger().error(f"key: {key}, value: {value}")
            sleep_time_sec = value['reset'] - datetime.now().timestamp()
            sleep_time_hour = sleep_time_sec / 3600.0
            get_logger().error(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours")
            get_logger().error(
                f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours"
            )
            if sleep_time_sec > 0:
                time.sleep(sleep_time_sec + 1)
    rate_limit_status = get_rate_limit_status(github_token)
@@ -1068,22 +1233,39 @@ def github_action_output(output_data: dict, key_name: str):

        key_data = output_data.get(key_name, {})
        with open(os.environ['GITHUB_OUTPUT'], 'a') as fh:
            print(f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}", file=fh)
            print(
                f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}",
                file=fh,
            )
    except Exception as e:
        get_logger().error(f"Failed to write to GitHub Action output: {e}")
    return

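github_action_output follows the documented GitHub Actions convention: append key=value lines to the file named by the GITHUB_OUTPUT environment variable. A minimal reproduction outside a workflow runner, substituting a temp file for the runner-provided path:

import json, os, tempfile

os.environ['GITHUB_OUTPUT'] = tempfile.mkstemp()[1]  # stand-in for the runner's file
key_data = {"review": {"score": 9}}
with open(os.environ['GITHUB_OUTPUT'], 'a') as fh:
    print(f"review={json.dumps(key_data['review'], indent=None, ensure_ascii=False)}", file=fh)
print(open(os.environ['GITHUB_OUTPUT']).read())  # review={"score": 9}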
def show_relevant_configurations(relevant_section: str) -> str:
    skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys", "app_id", "redirect",
                 'trial_prefix_message', 'no_eligible_message', 'identity_provider', 'ALLOWED_REPOS','APP_NAME']
    skip_keys = [
        'ai_disclaimer',
        'ai_disclaimer_title',
        'ANALYTICS_FOLDER',
        'secret_provider',
        "skip_keys",
        "app_id",
        "redirect",
        'trial_prefix_message',
        'no_eligible_message',
        'identity_provider',
        'ALLOWED_REPOS',
        'APP_NAME',
    ]
    extra_skip_keys = get_settings().config.get('config.skip_keys', [])
    if extra_skip_keys:
        skip_keys.extend(extra_skip_keys)

    markdown_text = ""
    markdown_text += "\n<hr>\n<details> <summary><strong>🛠️ 相关配置:</strong></summary> \n\n"
    markdown_text +="<br>以下是相关工具地配置 [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml):\n\n"
    markdown_text += (
        "\n<hr>\n<details> <summary><strong>🛠️ 相关配置:</strong></summary> \n\n"
    )
    markdown_text += "<br>以下是相关工具地配置 [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml):\n\n"
    markdown_text += f"**[config**]\n```yaml\n\n"
    for key, value in get_settings().config.items():
        if key in skip_keys:
@@ -1099,6 +1281,7 @@ def show_relevant_configurations(relevant_section: str) -> str:
    markdown_text += "\n</details>\n"
    return markdown_text


def is_value_no(value):
    if not value:
        return True
@@ -1122,7 +1305,7 @@ def string_to_uniform_number(s: str) -> float:
    # Convert the hash to an integer
    hash_int = int(hash_object.hexdigest(), 16)
    # Normalize the integer to the range [0, 1]
    max_hash_int = 2 ** 256 - 1
    max_hash_int = 2**256 - 1
    uniform_number = float(hash_int) / max_hash_int
    return uniform_number

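The 2**256 - 1 edit is purely cosmetic, but the surrounding function is worth a worked example: hash a string, interpret the 256-bit digest as an integer, and scale it into [0, 1] for deterministic per-string sampling. SHA-256 is assumed here, consistent with the 2**256 - 1 divisor:

import hashlib

def string_to_uniform_number_sketch(s: str) -> float:
    # Same digest always yields the same point in [0, 1].
    hash_int = int(hashlib.sha256(s.encode()).hexdigest(), 16)
    return float(hash_int) / (2**256 - 1)

x = string_to_uniform_number_sketch("pr-agent")
assert 0.0 <= x <= 1.0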
@@ -1131,7 +1314,9 @@ def process_description(description_full: str) -> Tuple[str, List]:
    if not description_full:
        return "", []

    description_split = description_full.split(PRDescriptionHeader.CHANGES_WALKTHROUGH.value)
    description_split = description_full.split(
        PRDescriptionHeader.CHANGES_WALKTHROUGH.value
    )
    base_description_str = description_split[0]
    changes_walkthrough_str = ""
    files = []
@@ -1167,45 +1352,58 @@ def process_description(description_full: str) -> Tuple[str, List]:
            pattern_back = r'<details>\s*<summary><strong>(.*?)</strong><dd><code>(.*?)</code>.*?</summary>\s*<hr>\s*(.*?)\n\n\s*(.*?)</details>'
            res = re.search(pattern_back, file_data, re.DOTALL)
            if not res or res.lastindex != 4:
                pattern_back = r'<details>\s*<summary><strong>(.*?)</strong>\s*<dd><code>(.*?)</code>.*?</summary>\s*<hr>\s*(.*?)\s*-\s*(.*?)\s*</details>' # looking for hypen ('- ')
                pattern_back = r'<details>\s*<summary><strong>(.*?)</strong>\s*<dd><code>(.*?)</code>.*?</summary>\s*<hr>\s*(.*?)\s*-\s*(.*?)\s*</details>'  # looking for hypen ('- ')
                res = re.search(pattern_back, file_data, re.DOTALL)
            if res and res.lastindex == 4:
                short_filename = res.group(1).strip()
                short_summary = res.group(2).strip()
                long_filename = res.group(3).strip()
                long_summary = res.group(4).strip()
                long_summary = long_summary.replace('<br> *', '\n*').replace('<br>','').replace('\n','<br>')
                long_summary = res.group(4).strip()
                long_summary = (
                    long_summary.replace('<br> *', '\n*')
                    .replace('<br>', '')
                    .replace('\n', '<br>')
                )
                long_summary = h.handle(long_summary).strip()
                if long_summary.startswith('\\-'):
                    long_summary = "* " + long_summary[2:]
                elif not long_summary.startswith('*'):
                    long_summary = f"* {long_summary}"

                files.append({
                    'short_file_name': short_filename,
                    'full_file_name': long_filename,
                    'short_summary': short_summary,
                    'long_summary': long_summary
                })
                files.append(
                    {
                        'short_file_name': short_filename,
                        'full_file_name': long_filename,
                        'short_summary': short_summary,
                        'long_summary': long_summary,
                    }
                )
            else:
                if '<code>...</code>' in file_data:
                    pass # PR with many files. some did not get analyzed
                    pass  # PR with many files. some did not get analyzed
                else:
                    get_logger().error(f"Failed to parse description", artifact={'description': file_data})
                    get_logger().error(
                        f"Failed to parse description",
                        artifact={'description': file_data},
                    )
        except Exception as e:
            get_logger().exception(f"Failed to process description: {e}", artifact={'description': file_data})

            get_logger().exception(
                f"Failed to process description: {e}",
                artifact={'description': file_data},
            )

    except Exception as e:
        get_logger().exception(f"Failed to process description: {e}")

    return base_description_str, files


def get_version() -> str:
    # First check pyproject.toml if running directly out of repository
    if os.path.exists("pyproject.toml"):
        if sys.version_info >= (3, 11):
            import tomllib

            with open("pyproject.toml", "rb") as f:
                data = tomllib.load(f)
                if "project" in data and "version" in data["project"]:
@@ -1213,7 +1411,9 @@ def get_version() -> str:
                else:
                    get_logger().warning("Version not found in pyproject.toml")
        else:
            get_logger().warning("Unable to determine local version from pyproject.toml")
            get_logger().warning(
                "Unable to determine local version from pyproject.toml"
            )

    # Otherwise get the installed pip package version
    try:

@@ -12,8 +12,9 @@ setup_logger(log_level)


def set_parser():
    parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage=
                                     """\
    parser = argparse.ArgumentParser(
        description='AI based pull request analyzer',
        usage="""\
Usage: cli.py --pr-url=<URL on supported git hosting service> <command> [<args>].
For example:
- cli.py --pr_url=... review
@@ -45,11 +46,20 @@ def set_parser():
Configuration:
To edit any configuration parameter from 'configuration.toml', just add -config_path=<value>.
For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
""")
    parser.add_argument('--version', action='version', version=f'pr-agent {get_version()}')
    parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', default=None)
    parser.add_argument('--issue_url', type=str, help='The URL of the Issue to review', default=None)
    parser.add_argument('command', type=str, help='The', choices=commands, default='review')
""",
    )
    parser.add_argument(
        '--version', action='version', version=f'pr-agent {get_version()}'
    )
    parser.add_argument(
        '--pr_url', type=str, help='The URL of the PR to review', default=None
    )
    parser.add_argument(
        '--issue_url', type=str, help='The URL of the Issue to review', default=None
    )
    parser.add_argument(
        'command', type=str, help='The', choices=commands, default='review'
    )
    parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
    return parser

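A sanity check of the reflowed parser shape, rebuilt in miniature since the real module pulls in the full commands list; the choices here are illustrative only:

import argparse

parser = argparse.ArgumentParser(description='AI based pull request analyzer')
parser.add_argument('--pr_url', type=str, default=None)
parser.add_argument('command', type=str, choices=['review', 'describe', 'ask'], default='review')
parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
args = parser.parse_args(['--pr_url=https://github.com/org/repo/pull/1', 'review'])
print(args.command, args.rest)  # review []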
@@ -76,14 +86,24 @@ def run(inargs=None, args=None):

    async def inner():
        if args.issue_url:
            result = await asyncio.create_task(PRAgent().handle_request(args.issue_url, [command] + args.rest))
            result = await asyncio.create_task(
                PRAgent().handle_request(args.issue_url, [command] + args.rest)
            )
        else:
            result = await asyncio.create_task(PRAgent().handle_request(args.pr_url, [command] + args.rest))
            result = await asyncio.create_task(
                PRAgent().handle_request(args.pr_url, [command] + args.rest)
            )

        if get_settings().litellm.get("enable_callbacks", False):
            # There may be additional events on the event queue from the run above. If there are give them time to complete.
            get_logger().debug("Waiting for event queue to complete")
            await asyncio.wait([task for task in asyncio.all_tasks() if task is not asyncio.current_task()])
            await asyncio.wait(
                [
                    task
                    for task in asyncio.all_tasks()
                    if task is not asyncio.current_task()
                ]
            )

        return result

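The reindented asyncio.wait call drains every task except the current one, which is what gives litellm callback events time to flush. The same pattern in isolation:

import asyncio

async def main():
    asyncio.create_task(asyncio.sleep(0.1))  # stand-in for a lingering callback task
    pending = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    if pending:  # asyncio.wait requires a non-empty collection
        await asyncio.wait(pending)

asyncio.run(main())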
@@ -7,7 +7,9 @@ def main():
    provider = "github" # GitHub provider
    user_token = "..." # GitHub user token
    openai_key = "..." # OpenAI key
    pr_url = "..." # PR URL, for example 'https://github.com/Codium-ai/pr-agent/pull/809'
    pr_url = (
        "..."  # PR URL, for example 'https://github.com/Codium-ai/pr-agent/pull/809'
    )
    command = "/review" # Command to run (e.g. '/review', '/describe', '/ask="What is the purpose of this PR?"')

    # Setting the configurations

@@ -11,26 +11,29 @@ current_dir = dirname(abspath(__file__))
global_settings = Dynaconf(
    envvar_prefix=False,
    merge_enabled=True,
    settings_files=[join(current_dir, f) for f in [
        "settings/configuration.toml",
        "settings/ignore.toml",
        "settings/language_extensions.toml",
        "settings/pr_reviewer_prompts.toml",
        "settings/pr_questions_prompts.toml",
        "settings/pr_line_questions_prompts.toml",
        "settings/pr_description_prompts.toml",
        "settings/pr_code_suggestions_prompts.toml",
        "settings/pr_code_suggestions_reflect_prompts.toml",
        "settings/pr_sort_code_suggestions_prompts.toml",
        "settings/pr_information_from_user_prompts.toml",
        "settings/pr_update_changelog_prompts.toml",
        "settings/pr_custom_labels.toml",
        "settings/pr_add_docs.toml",
        "settings/custom_labels.toml",
        "settings/pr_help_prompts.toml",
        "settings/.secrets.toml",
        "settings_prod/.secrets.toml",
    ]]
    settings_files=[
        join(current_dir, f)
        for f in [
            "settings/configuration.toml",
            "settings/ignore.toml",
            "settings/language_extensions.toml",
            "settings/pr_reviewer_prompts.toml",
            "settings/pr_questions_prompts.toml",
            "settings/pr_line_questions_prompts.toml",
            "settings/pr_description_prompts.toml",
            "settings/pr_code_suggestions_prompts.toml",
            "settings/pr_code_suggestions_reflect_prompts.toml",
            "settings/pr_sort_code_suggestions_prompts.toml",
            "settings/pr_information_from_user_prompts.toml",
            "settings/pr_update_changelog_prompts.toml",
            "settings/pr_custom_labels.toml",
            "settings/pr_add_docs.toml",
            "settings/custom_labels.toml",
            "settings/pr_help_prompts.toml",
            "settings/.secrets.toml",
            "settings_prod/.secrets.toml",
        ]
    ],
)

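The rewritten settings_files comprehension resolves to the same list; Dynaconf loads the TOML files in order, with merge_enabled=True letting later files override earlier keys. A reduced sketch with two hypothetical files:

from dynaconf import Dynaconf

settings = Dynaconf(
    envvar_prefix=False,
    merge_enabled=True,
    settings_files=["settings/configuration.toml", "settings/.secrets.toml"],
)
# Keys then resolve through the merged view, e.g. settings.config.git_provider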
@@ -3,8 +3,9 @@ from starlette_context import context
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider
from utils.pr_agent.git_providers.bitbucket_provider import BitbucketProvider
from utils.pr_agent.git_providers.bitbucket_server_provider import \
    BitbucketServerProvider
from utils.pr_agent.git_providers.bitbucket_server_provider import (
    BitbucketServerProvider,
)
from utils.pr_agent.git_providers.codecommit_provider import CodeCommitProvider
from utils.pr_agent.git_providers.gerrit_provider import GerritProvider
from utils.pr_agent.git_providers.git_provider import GitProvider
@@ -28,7 +29,9 @@ def get_git_provider():
    try:
        provider_id = get_settings().config.git_provider
    except AttributeError as e:
        raise ValueError("git_provider is a required attribute in the configuration file") from e
        raise ValueError(
            "git_provider is a required attribute in the configuration file"
        ) from e
    if provider_id not in _GIT_PROVIDERS:
        raise ValueError(f"Unknown git provider: {provider_id}")
    return _GIT_PROVIDERS[provider_id]

@@ -6,25 +6,33 @@ from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo

from ..algo.file_filter import filter_ignored
from ..algo.language_handler import is_valid_file
from ..algo.utils import (PRDescriptionHeader, find_line_number_of_relevant_line_in_file,
                          load_large_diff)
from ..algo.utils import (
    PRDescriptionHeader,
    find_line_number_of_relevant_line_in_file,
    load_large_diff,
)
from ..config_loader import get_settings
from ..log import get_logger
from .git_provider import GitProvider

AZURE_DEVOPS_AVAILABLE = True
ADO_APP_CLIENT_DEFAULT_ID = "499b84ac-1321-427f-aa17-267ca6975798/.default"
MAX_PR_DESCRIPTION_AZURE_LENGTH = 4000-1
MAX_PR_DESCRIPTION_AZURE_LENGTH = 4000 - 1

try:
    # noinspection PyUnresolvedReferences
    # noinspection PyUnresolvedReferences
    from azure.devops.connection import Connection

    # noinspection PyUnresolvedReferences
    from azure.devops.v7_1.git.models import (Comment, CommentThread,
                                              GitPullRequest,
                                              GitPullRequestIterationChanges,
                                              GitVersionDescriptor)
    from azure.devops.v7_1.git.models import (
        Comment,
        CommentThread,
        GitPullRequest,
        GitPullRequestIterationChanges,
        GitVersionDescriptor,
    )

    # noinspection PyUnresolvedReferences
    from azure.identity import DefaultAzureCredential
    from msrest.authentication import BasicAuthentication
@@ -33,9 +41,8 @@ except ImportError:

class AzureDevopsProvider(GitProvider):

    def __init__(
        self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
        self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
    ):
        if not AZURE_DEVOPS_AVAILABLE:
            raise ImportError(
@@ -67,13 +74,16 @@ class AzureDevopsProvider(GitProvider):

            if not relevant_lines_start or relevant_lines_start == -1:
                get_logger().warning(
                    f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
                    f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}"
                )
                continue

            if relevant_lines_end < relevant_lines_start:
                get_logger().warning(f"Failed to publish code suggestion, "
                                     f"relevant_lines_end is {relevant_lines_end} and "
                                     f"relevant_lines_start is {relevant_lines_start}")
                get_logger().warning(
                    f"Failed to publish code suggestion, "
                    f"relevant_lines_end is {relevant_lines_end} and "
                    f"relevant_lines_start is {relevant_lines_start}"
                )
                continue

            if relevant_lines_end > relevant_lines_start:
@@ -98,30 +108,32 @@ class AzureDevopsProvider(GitProvider):
        for post_parameters in post_parameters_list:
            try:
                comment = Comment(content=post_parameters["body"], comment_type=1)
                thread = CommentThread(comments=[comment],
                                       thread_context={
                                           "filePath": post_parameters["path"],
                                           "rightFileStart": {
                                               "line": post_parameters["start_line"],
                                               "offset": 1,
                                           },
                                           "rightFileEnd": {
                                               "line": post_parameters["line"],
                                               "offset": 1,
                                           },
                                       })
                thread = CommentThread(
                    comments=[comment],
                    thread_context={
                        "filePath": post_parameters["path"],
                        "rightFileStart": {
                            "line": post_parameters["start_line"],
                            "offset": 1,
                        },
                        "rightFileEnd": {
                            "line": post_parameters["line"],
                            "offset": 1,
                        },
                    },
                )
                self.azure_devops_client.create_thread(
                    comment_thread=thread,
                    project=self.workspace_slug,
                    repository_id=self.repo_slug,
                    pull_request_id=self.pr_num
                    pull_request_id=self.pr_num,
                )
            except Exception as e:
                get_logger().warning(f"Azure failed to publish code suggestion, error: {e}")
                get_logger().warning(
                    f"Azure failed to publish code suggestion, error: {e}"
                )
        return True



    def get_pr_description_full(self) -> str:
        return self.pr.description

@@ -204,9 +216,9 @@ class AzureDevopsProvider(GitProvider):
    def get_files(self):
        files = []
        for i in self.azure_devops_client.get_pull_request_commits(
                project=self.workspace_slug,
                repository_id=self.repo_slug,
                pull_request_id=self.pr_num,
            project=self.workspace_slug,
            repository_id=self.repo_slug,
            pull_request_id=self.pr_num,
        ):
            changes_obj = self.azure_devops_client.get_changes(
                project=self.workspace_slug,
@@ -220,7 +232,6 @@ class AzureDevopsProvider(GitProvider):

    def get_diff_files(self) -> list[FilePatchInfo]:
        try:

            if self.diff_files:
                return self.diff_files

@@ -231,18 +242,20 @@ class AzureDevopsProvider(GitProvider):
            iterations = self.azure_devops_client.get_pull_request_iterations(
                repository_id=self.repo_slug,
                pull_request_id=self.pr_num,
                project=self.workspace_slug
                project=self.workspace_slug,
            )
            changes = None
            if iterations:
                iteration_id = iterations[-1].id # Get the last iteration (most recent changes)
                iteration_id = iterations[
                    -1
                ].id  # Get the last iteration (most recent changes)

                # Get changes for the iteration
                changes = self.azure_devops_client.get_pull_request_iteration_changes(
                    repository_id=self.repo_slug,
                    pull_request_id=self.pr_num,
                    iteration_id=iteration_id,
                    project=self.workspace_slug
                    project=self.workspace_slug,
                )
            diff_files = []
            diffs = []
@@ -253,7 +266,9 @@ class AzureDevopsProvider(GitProvider):
                    path = item.get('path', None)
                    if path:
                        diffs.append(path)
                        diff_types[path] = change.additional_properties.get('changeType', 'Unknown')
                        diff_types[path] = change.additional_properties.get(
                            'changeType', 'Unknown'
                        )

            # wrong implementation - gets all the files that were changed in any commit in the PR
            # commits = self.azure_devops_client.get_pull_request_commits(
@@ -284,9 +299,13 @@ class AzureDevopsProvider(GitProvider):
            diffs = filter_ignored(diffs_original, 'azure')
            if diffs_original != diffs:
                try:
                    get_logger().info(f"Filtered out [ignore] files for pull request:", extra=
                    {"files": diffs_original, # diffs is just a list of names
                     "filtered_files": diffs})
                    get_logger().info(
                        f"Filtered out [ignore] files for pull request:",
                        extra={
                            "files": diffs_original, # diffs is just a list of names
                            "filtered_files": diffs,
                        },
                    )
                except Exception:
                    pass

@@ -311,7 +330,10 @@ class AzureDevopsProvider(GitProvider):

                    new_file_content_str = new_file_content_str.content
                except Exception as error:
                    get_logger().error(f"Failed to retrieve new file content of {file} at version {version}", error=error)
                    get_logger().error(
                        f"Failed to retrieve new file content of {file} at version {version}",
                        error=error,
                    )
                    # get_logger().error(
                    #     "Failed to retrieve new file content of %s at version %s. Error: %s",
                    #     file,
@@ -325,7 +347,9 @@ class AzureDevopsProvider(GitProvider):
                    edit_type = EDIT_TYPE.ADDED
                elif diff_types[file] == "delete":
                    edit_type = EDIT_TYPE.DELETED
                elif "rename" in diff_types[file]: # diff_type can be `rename` | `edit, rename`
                elif (
                    "rename" in diff_types[file]
                ): # diff_type can be `rename` | `edit, rename`
                    edit_type = EDIT_TYPE.RENAMED

                version = GitVersionDescriptor(
@@ -345,17 +369,27 @@ class AzureDevopsProvider(GitProvider):
                    )
                    original_file_content_str = original_file_content_str.content
                except Exception as error:
                    get_logger().error(f"Failed to retrieve original file content of {file} at version {version}", error=error)
                    get_logger().error(
                        f"Failed to retrieve original file content of {file} at version {version}",
                        error=error,
                    )
                    original_file_content_str = ""

                patch = load_large_diff(
                    file, new_file_content_str, original_file_content_str, show_warning=False
                    file,
                    new_file_content_str,
                    original_file_content_str,
                    show_warning=False,
                ).rstrip()

                # count number of lines added and removed
                patch_lines = patch.splitlines(keepends=True)
                num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
                num_minus_lines = len([line for line in patch_lines if line.startswith('-')])
                num_plus_lines = len(
                    [line for line in patch_lines if line.startswith('+')]
                )
                num_minus_lines = len(
                    [line for line in patch_lines if line.startswith('-')]
                )

                diff_files.append(
                    FilePatchInfo(
@@ -376,27 +410,35 @@ class AzureDevopsProvider(GitProvider):
            get_logger().exception(f"Failed to get diff files, error: {e}")
            return []

    def publish_comment(self, pr_comment: str, is_temporary: bool = False, thread_context=None):
    def publish_comment(
        self, pr_comment: str, is_temporary: bool = False, thread_context=None
    ):
        if is_temporary and not get_settings().config.publish_output_progress:
            get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
            get_logger().debug(
                f"Skipping publish_comment for temporary comment: {pr_comment}"
            )
            return None
        comment = Comment(content=pr_comment)
        thread = CommentThread(comments=[comment], thread_context=thread_context, status=5)
        thread = CommentThread(
            comments=[comment], thread_context=thread_context, status=5
        )
        thread_response = self.azure_devops_client.create_thread(
            comment_thread=thread,
            project=self.workspace_slug,
            repository_id=self.repo_slug,
            pull_request_id=self.pr_num,
        )
        response = {"thread_id": thread_response.id, "comment_id": thread_response.comments[0].id}
        response = {
            "thread_id": thread_response.id,
            "comment_id": thread_response.comments[0].id,
        }
        if is_temporary:
            self.temp_comments.append(response)
        return response

    def publish_description(self, pr_title: str, pr_body: str):
        if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:

            usage_guide_text='<details> <summary><strong>✨ Describe tool usage guide:</strong></summary><hr>'
            usage_guide_text = '<details> <summary><strong>✨ Describe tool usage guide:</strong></summary><hr>'
            ind = pr_body.find(usage_guide_text)
            if ind != -1:
                pr_body = pr_body[:ind]
@@ -409,7 +451,10 @@ class AzureDevopsProvider(GitProvider):

            if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:
                trunction_message = " ... (description truncated due to length limit)"
                pr_body = pr_body[:MAX_PR_DESCRIPTION_AZURE_LENGTH - len(trunction_message)] + trunction_message
                pr_body = (
                    pr_body[: MAX_PR_DESCRIPTION_AZURE_LENGTH - len(trunction_message)]
                    + trunction_message
                )
                get_logger().warning("PR description was truncated due to length limit")
        try:
            updated_pr = GitPullRequest()
@@ -433,50 +478,79 @@ class AzureDevopsProvider(GitProvider):
        except Exception as e:
            get_logger().exception(f"Failed to remove temp comments, error: {e}")

    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
        self.publish_inline_comments([self.create_inline_comment(body, relevant_file, relevant_line_in_file)])
    def publish_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        original_suggestion=None,
    ):
        self.publish_inline_comments(
            [self.create_inline_comment(body, relevant_file, relevant_line_in_file)]
        )


    def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
                              absolute_position: int = None):
        position, absolute_position = find_line_number_of_relevant_line_in_file(self.get_diff_files(),
                                                                                relevant_file.strip('`'),
                                                                                relevant_line_in_file,
                                                                                absolute_position)
    def create_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        absolute_position: int = None,
    ):
        position, absolute_position = find_line_number_of_relevant_line_in_file(
            self.get_diff_files(),
            relevant_file.strip('`'),
            relevant_line_in_file,
            absolute_position,
        )
        if position == -1:
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
                get_logger().info(
                    f"Could not find position for {relevant_file} {relevant_line_in_file}"
                )
            subject_type = "FILE"
        else:
            subject_type = "LINE"
        path = relevant_file.strip()
        return dict(body=body, path=path, position=position, absolute_position=absolute_position) if subject_type == "LINE" else {}
        return (
            dict(
                body=body,
                path=path,
                position=position,
                absolute_position=absolute_position,
            )
            if subject_type == "LINE"
            else {}
        )
    def publish_inline_comments(self, comments: list[dict], disable_fallback: bool = False):
        overall_success = True
        for comment in comments:
            try:
                self.publish_comment(comment["body"],
                                     thread_context={
                                         "filePath": comment["path"],
                                         "rightFileStart": {
                                             "line": comment["absolute_position"],
                                             "offset": comment["position"],
                                         },
                                         "rightFileEnd": {
                                             "line": comment["absolute_position"],
                                             "offset": comment["position"],
                                         },
                                     })
                if get_settings().config.verbosity_level >= 2:
                    get_logger().info(
                        f"Published code suggestion on {self.pr_num} at {comment['path']}"
                    )
            except Exception as e:
                if get_settings().config.verbosity_level >= 2:
                    get_logger().error(f"Failed to publish code suggestion, error: {e}")
                overall_success = False
        return overall_success
    def publish_inline_comments(
        self, comments: list[dict], disable_fallback: bool = False
    ):
        overall_success = True
        for comment in comments:
            try:
                self.publish_comment(
                    comment["body"],
                    thread_context={
                        "filePath": comment["path"],
                        "rightFileStart": {
                            "line": comment["absolute_position"],
                            "offset": comment["position"],
                        },
                        "rightFileEnd": {
                            "line": comment["absolute_position"],
                            "offset": comment["position"],
                        },
                    },
                )
                if get_settings().config.verbosity_level >= 2:
                    get_logger().info(
                        f"Published code suggestion on {self.pr_num} at {comment['path']}"
                    )
            except Exception as e:
                if get_settings().config.verbosity_level >= 2:
                    get_logger().error(f"Failed to publish code suggestion, error: {e}")
                overall_success = False
        return overall_success

    def get_title(self):
        return self.pr.title
@@ -521,7 +595,11 @@ class AzureDevopsProvider(GitProvider):
        return 0

    def get_issue_comments(self):
        threads = self.azure_devops_client.get_threads(repository_id=self.repo_slug, pull_request_id=self.pr_num, project=self.workspace_slug)
        threads = self.azure_devops_client.get_threads(
            repository_id=self.repo_slug,
            pull_request_id=self.pr_num,
            project=self.workspace_slug,
        )
        threads.reverse()
        comment_list = []
        for thread in threads:
@@ -532,7 +610,9 @@ class AzureDevopsProvider(GitProvider):
                comment_list.append(comment)
        return comment_list

    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
    def add_eyes_reaction(
        self, issue_comment_id: int, disable_eyes: bool = False
    ) -> Optional[int]:
        return True

    def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
@@ -547,16 +627,22 @@ class AzureDevopsProvider(GitProvider):
            raise ValueError(
                "The provided URL does not appear to be a Azure DevOps PR URL"
            )
        if len(path_parts) == 6: # "https://dev.azure.com/organization/project/_git/repo/pullrequest/1"
        if (
            len(path_parts) == 6
        ): # "https://dev.azure.com/organization/project/_git/repo/pullrequest/1"
            workspace_slug = path_parts[1]
            repo_slug = path_parts[3]
            pr_number = int(path_parts[5])
        elif len(path_parts) == 5: # 'https://organization.visualstudio.com/project/_git/repo/pullrequest/1'
        elif (
            len(path_parts) == 5
        ): # 'https://organization.visualstudio.com/project/_git/repo/pullrequest/1'
            workspace_slug = path_parts[0]
            repo_slug = path_parts[2]
            pr_number = int(path_parts[4])
        else:
            raise ValueError("The provided URL does not appear to be a Azure DevOps PR URL")
            raise ValueError(
                "The provided URL does not appear to be a Azure DevOps PR URL"
            )

        return workspace_slug, repo_slug, pr_number

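The six- and five-part branches correspond to the two Azure DevOps URL shapes quoted in the inline comments. The same split as a standalone sketch (helper name hypothetical):

from urllib.parse import urlparse

def parse_ado_pr_url(url: str):
    path_parts = urlparse(url).path.strip('/').split('/')
    if len(path_parts) == 6:  # dev.azure.com/org/project/_git/repo/pullrequest/1
        return path_parts[1], path_parts[3], int(path_parts[5])
    if len(path_parts) == 5:  # org.visualstudio.com/project/_git/repo/pullrequest/1
        return path_parts[0], path_parts[2], int(path_parts[4])
    raise ValueError("The provided URL does not appear to be a Azure DevOps PR URL")

print(parse_ado_pr_url("https://dev.azure.com/org/project/_git/repo/pullrequest/1"))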
@@ -575,12 +661,16 @@ class AzureDevopsProvider(GitProvider):
            # try to use azure default credentials
            # see https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python
            # for usage and env var configuration of user-assigned managed identity, local machine auth etc.
            get_logger().info("No PAT found in settings, trying to use Azure Default Credentials.")
            get_logger().info(
                "No PAT found in settings, trying to use Azure Default Credentials."
            )
            credentials = DefaultAzureCredential()
            accessToken = credentials.get_token(ADO_APP_CLIENT_DEFAULT_ID)
            auth_token = accessToken.token
        except Exception as e:
            get_logger().error(f"No PAT found in settings, and Azure Default Authentication failed, error: {e}")
            get_logger().error(
                f"No PAT found in settings, and Azure Default Authentication failed, error: {e}"
            )
            raise

        credentials = BasicAuthentication("", auth_token)

@@ -52,13 +52,19 @@ class BitbucketProvider(GitProvider):
        self.git_files = None
        if pr_url:
            self.set_pr(pr_url)
            self.bitbucket_comment_api_url = self.pr._BitbucketBase__data["links"]["comments"]["href"]
            self.bitbucket_pull_request_api_url = self.pr._BitbucketBase__data["links"]['self']['href']
            self.bitbucket_comment_api_url = self.pr._BitbucketBase__data["links"][
                "comments"
            ]["href"]
            self.bitbucket_pull_request_api_url = self.pr._BitbucketBase__data["links"][
                'self'
            ]['href']

    def get_repo_settings(self):
        try:
            url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
                   f"{self.pr.destination_branch}/.pr_agent.toml")
            url = (
                f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
                f"{self.pr.destination_branch}/.pr_agent.toml"
            )
            response = requests.request("GET", url, headers=self.headers)
            if response.status_code == 404: # not found
                return ""
@@ -74,20 +80,27 @@ class BitbucketProvider(GitProvider):
        post_parameters_list = []
        for suggestion in code_suggestions:
            body = suggestion["body"]
            original_suggestion = suggestion.get('original_suggestion', None) # needed for diff code
            original_suggestion = suggestion.get(
                'original_suggestion', None
            ) # needed for diff code
            if original_suggestion:
                try:
                    existing_code = original_suggestion['existing_code'].rstrip() + "\n"
                    improved_code = original_suggestion['improved_code'].rstrip() + "\n"
                    diff = difflib.unified_diff(existing_code.split('\n'),
                                                improved_code.split('\n'), n=999)
                    diff = difflib.unified_diff(
                        existing_code.split('\n'), improved_code.split('\n'), n=999
                    )
                    patch_orig = "\n".join(diff)
                    patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
                    diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
                    # replace ```suggestion ... ``` with diff_code, using regex:
                    body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL)
                    body = re.sub(
                        r'```suggestion.*?```', diff_code, body, flags=re.DOTALL
                    )
                except Exception as e:
                    get_logger().exception(f"Bitbucket failed to get diff code for publishing, error: {e}")
                    get_logger().exception(
                        f"Bitbucket failed to get diff code for publishing, error: {e}"
                    )
                    continue

            relevant_file = suggestion["relevant_file"]
@@ -129,15 +142,22 @@ class BitbucketProvider(GitProvider):
            self.publish_inline_comments(post_parameters_list)
            return True
        except Exception as e:
            get_logger().error(f"Bitbucket failed to publish code suggestion, error: {e}")
            get_logger().error(
                f"Bitbucket failed to publish code suggestion, error: {e}"
            )
            return False

    def publish_file_comments(self, file_comments: list) -> bool:
        pass

    def is_supported(self, capability: str) -> bool:
        if capability in ['get_issue_comments', 'publish_inline_comments', 'get_labels', 'gfm_markdown',
                          'publish_file_comments']:
        if capability in [
            'get_issue_comments',
            'publish_inline_comments',
            'get_labels',
            'gfm_markdown',
            'publish_file_comments',
        ]:
            return False
        return True

@@ -169,12 +189,14 @@ class BitbucketProvider(GitProvider):
                names_original = [d.new.path for d in diffs_original]
                names_kept = [d.new.path for d in diffs]
                names_filtered = list(set(names_original) - set(names_kept))
                get_logger().info(f"Filtered out [ignore] files for PR", extra={
                    'original_files': names_original,
                    'names_kept': names_kept,
                    'names_filtered': names_filtered

                })
                get_logger().info(
                    f"Filtered out [ignore] files for PR",
                    extra={
                        'original_files': names_original,
                        'names_kept': names_kept,
                        'names_filtered': names_filtered,
                    },
                )
            except Exception as e:
                pass

@@ -189,20 +211,32 @@ class BitbucketProvider(GitProvider):
        for encoding in encodings_to_try:
            try:
                pr_patches = self.pr.diff(encoding=encoding)
                get_logger().info(f"Successfully decoded PR patch with encoding {encoding}")
                get_logger().info(
                    f"Successfully decoded PR patch with encoding {encoding}"
                )
                break
            except UnicodeDecodeError:
                continue

        if pr_patches is None:
            raise ValueError(f"Failed to decode PR patch with encodings {encodings_to_try}")
            raise ValueError(
                f"Failed to decode PR patch with encodings {encodings_to_try}"
            )

        diff_split = ["diff --git" + x for x in pr_patches.split("diff --git") if x.strip()]
        diff_split = [
            "diff --git" + x for x in pr_patches.split("diff --git") if x.strip()
        ]
        # filter all elements of 'diff_split' that are of indices in 'diffs_original' that are not in 'diffs'
        if len(diff_split) > len(diffs) and len(diffs_original) == len(diff_split):
            diff_split = [diff_split[i] for i in range(len(diff_split)) if diffs_original[i] in diffs]
            diff_split = [
                diff_split[i]
                for i in range(len(diff_split))
                if diffs_original[i] in diffs
            ]
        if len(diff_split) != len(diffs):
            get_logger().error(f"Error - failed to split the diff into {len(diffs)} parts")
            get_logger().error(
                f"Error - failed to split the diff into {len(diffs)} parts"
            )
            return []
        # bitbucket diff has a header for each file, we need to remove it:
        # "diff --git filename
@@ -213,22 +247,34 @@ class BitbucketProvider(GitProvider):
        #  @@ -... @@"
        for i, _ in enumerate(diff_split):
            diff_split_lines = diff_split[i].splitlines()
            if (len(diff_split_lines) >= 6) and \
                    ((diff_split_lines[2].startswith("---") and
                      diff_split_lines[3].startswith("+++") and
                      diff_split_lines[4].startswith("@@")) or
                     (diff_split_lines[3].startswith("---") and # new or deleted file
                      diff_split_lines[4].startswith("+++") and
                      diff_split_lines[5].startswith("@@"))):
            if (len(diff_split_lines) >= 6) and (
                (
                    diff_split_lines[2].startswith("---")
                    and diff_split_lines[3].startswith("+++")
                    and diff_split_lines[4].startswith("@@")
                )
                or (
                    diff_split_lines[3].startswith("---")
                    and diff_split_lines[4].startswith("+++") # new or deleted file
                    and diff_split_lines[5].startswith("@@")
                )
            ):
                diff_split[i] = "\n".join(diff_split_lines[4:])
            else:
                if diffs[i].data.get('lines_added', 0) == 0 and diffs[i].data.get('lines_removed', 0) == 0:
                if (
                    diffs[i].data.get('lines_added', 0) == 0
                    and diffs[i].data.get('lines_removed', 0) == 0
                ):
                    diff_split[i] = ""
                elif len(diff_split_lines) <= 3:
                    diff_split[i] = ""
                    get_logger().info(f"Disregarding empty diff for file {_gef_filename(diffs[i])}")
                    get_logger().info(
                        f"Disregarding empty diff for file {_gef_filename(diffs[i])}"
                    )
                else:
                    get_logger().warning(f"Bitbucket failed to get diff for file {_gef_filename(diffs[i])}")
                    get_logger().warning(
                        f"Bitbucket failed to get diff for file {_gef_filename(diffs[i])}"
                    )
                    diff_split[i] = ""

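Splitting a multi-file patch on "diff --git" and re-prefixing each chunk, as above, is a standard trick; the guard that follows then strips each file's header lines up to the first @@ hunk. The split in minimal form:

pr_patches = (
    "diff --git a/a.py b/a.py\n--- a/a.py\n+++ b/a.py\n@@ -1 +1 @@\n-x\n+y\n"
    "diff --git a/b.py b/b.py\n--- a/b.py\n+++ b/b.py\n@@ -1 +1 @@\n-u\n+v\n"
)
diff_split = ["diff --git" + x for x in pr_patches.split("diff --git") if x.strip()]
print(len(diff_split))  # 2, one chunk per file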
        invalid_files_names = []
@@ -246,24 +292,32 @@ class BitbucketProvider(GitProvider):
                if get_settings().get("bitbucket_app.avoid_full_files", False):
                    original_file_content_str = ""
                    new_file_content_str = ""
                elif counter_valid < MAX_FILES_ALLOWED_FULL // 2: # factor 2 because bitbucket has limited API calls
                elif (
                    counter_valid < MAX_FILES_ALLOWED_FULL // 2
                ): # factor 2 because bitbucket has limited API calls
                    if diff.old.get_data("links"):
                        original_file_content_str = self._get_pr_file_content(
                            diff.old.get_data("links")['self']['href'])
                            diff.old.get_data("links")['self']['href']
                        )
                    else:
                        original_file_content_str = ""
                    if diff.new.get_data("links"):
                        new_file_content_str = self._get_pr_file_content(diff.new.get_data("links")['self']['href'])
                        new_file_content_str = self._get_pr_file_content(
                            diff.new.get_data("links")['self']['href']
                        )
                    else:
                        new_file_content_str = ""
                else:
                    if counter_valid == MAX_FILES_ALLOWED_FULL // 2:
                        get_logger().info(
                            f"Bitbucket too many files in PR, will avoid loading full content for rest of files")
                            f"Bitbucket too many files in PR, will avoid loading full content for rest of files"
                        )
                    original_file_content_str = ""
                    new_file_content_str = ""
            except Exception as e:
                get_logger().exception(f"Error - bitbucket failed to get file content, error: {e}")
                get_logger().exception(
                    f"Error - bitbucket failed to get file content, error: {e}"
                )
                original_file_content_str = ""
                new_file_content_str = ""

@@ -285,7 +339,9 @@ class BitbucketProvider(GitProvider):
            diff_files.append(file_patch_canonic_structure)

        if invalid_files_names:
            get_logger().info(f"Disregarding files with invalid extensions:\n{invalid_files_names}")
            get_logger().info(
                f"Disregarding files with invalid extensions:\n{invalid_files_names}"
            )

        self.diff_files = diff_files
        return diff_files
@@ -296,11 +352,14 @@ class BitbucketProvider(GitProvider):
    def get_comment_url(self, comment):
        return comment.data['links']['html']['href']

    def publish_persistent_comment(self, pr_comment: str,
                                   initial_header: str,
                                   update_header: bool = True,
                                   name='review',
                                   final_update_message=True):
    def publish_persistent_comment(
        self,
        pr_comment: str,
        initial_header: str,
        update_header: bool = True,
        name='review',
        final_update_message=True,
    ):
        try:
            for comment in self.pr.comments():
                body = comment.raw
@@ -309,15 +368,20 @@ class BitbucketProvider(GitProvider):
                    comment_url = self.get_comment_url(comment)
                    if update_header:
                        updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n"
                        pr_comment_updated = pr_comment.replace(initial_header, updated_header)
                        pr_comment_updated = pr_comment.replace(
                            initial_header, updated_header
                        )
                    else:
                        pr_comment_updated = pr_comment
                    get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message")
                    get_logger().info(
                        f"Persistent mode - updating comment {comment_url} to latest {name} message"
                    )
                    d = {"content": {"raw": pr_comment_updated}}
                    response = comment._update_data(comment.put(None, data=d))
                    if final_update_message:
                        self.publish_comment(
                            f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}")
                            f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}"
                        )
                    return
        except Exception as e:
            get_logger().exception(f"Failed to update persistent review, error: {e}")
@@ -326,7 +390,9 @@ class BitbucketProvider(GitProvider):

    def publish_comment(self, pr_comment: str, is_temporary: bool = False):
        if is_temporary and not get_settings().config.publish_output_progress:
            get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
            get_logger().debug(
                f"Skipping publish_comment for temporary comment: {pr_comment}"
            )
            return None
        pr_comment = self.limit_output_characters(pr_comment, self.max_comment_length)
        comment = self.pr.comment(pr_comment)
@@ -355,39 +421,58 @@ class BitbucketProvider(GitProvider):
            get_logger().exception(f"Failed to remove comment, error: {e}")

    # function to create_inline_comment
    def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
                              absolute_position: int = None):
    def create_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        absolute_position: int = None,
    ):
        body = self.limit_output_characters(body, self.max_comment_length)
        position, absolute_position = find_line_number_of_relevant_line_in_file(self.get_diff_files(),
                                                                                relevant_file.strip('`'),
                                                                                relevant_line_in_file,
                                                                                absolute_position)
        position, absolute_position = find_line_number_of_relevant_line_in_file(
            self.get_diff_files(),
            relevant_file.strip('`'),
            relevant_line_in_file,
            absolute_position,
        )
        if position == -1:
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
                get_logger().info(
                    f"Could not find position for {relevant_file} {relevant_line_in_file}"
                )
            subject_type = "FILE"
        else:
            subject_type = "LINE"
        path = relevant_file.strip()
        return dict(body=body, path=path, position=absolute_position) if subject_type == "LINE" else {}
        return (
            dict(body=body, path=path, position=absolute_position)
            if subject_type == "LINE"
            else {}
        )

def publish_inline_comment(self, comment: str, from_line: int, file: str, original_suggestion=None):
|
||||
def publish_inline_comment(
|
||||
self, comment: str, from_line: int, file: str, original_suggestion=None
|
||||
):
|
||||
comment = self.limit_output_characters(comment, self.max_comment_length)
|
||||
payload = json.dumps({
|
||||
"content": {
|
||||
"raw": comment,
|
||||
},
|
||||
"inline": {
|
||||
"to": from_line,
|
||||
"path": file
|
||||
},
|
||||
})
|
||||
payload = json.dumps(
|
||||
{
|
||||
"content": {
|
||||
"raw": comment,
|
||||
},
|
||||
"inline": {"to": from_line, "path": file},
|
||||
}
|
||||
)
|
||||
response = requests.request(
|
||||
"POST", self.bitbucket_comment_api_url, data=payload, headers=self.headers
|
||||
)
|
||||
return response
|
||||
|
||||
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
|
||||
def get_line_link(
|
||||
self,
|
||||
relevant_file: str,
|
||||
relevant_line_start: int,
|
||||
relevant_line_end: int = None,
|
||||
) -> str:
|
||||
if relevant_line_start == -1:
|
||||
link = f"{self.pr_url}/#L{relevant_file}"
|
||||
else:
|
||||
@ -402,8 +487,9 @@ class BitbucketProvider(GitProvider):
|
||||
return ""
|
||||
|
||||
diff_files = self.get_diff_files()
|
||||
position, absolute_position = find_line_number_of_relevant_line_in_file \
|
||||
(diff_files, relevant_file, relevant_line_str)
|
||||
position, absolute_position = find_line_number_of_relevant_line_in_file(
|
||||
diff_files, relevant_file, relevant_line_str
|
||||
)
|
||||
|
||||
if absolute_position != -1 and self.pr_url:
|
||||
link = f"{self.pr_url}/#L{relevant_file}T{absolute_position}"
|
||||
@ -417,12 +503,18 @@ class BitbucketProvider(GitProvider):
|
||||
def publish_inline_comments(self, comments: list[dict]):
|
||||
for comment in comments:
|
||||
if 'position' in comment:
|
||||
self.publish_inline_comment(comment['body'], comment['position'], comment['path'])
|
||||
self.publish_inline_comment(
|
||||
comment['body'], comment['position'], comment['path']
|
||||
)
|
||||
elif 'start_line' in comment: # multi-line comment
|
||||
# note that bitbucket does not seem to support range - only a comment on a single line - https://community.developer.atlassian.com/t/api-post-endpoint-for-inline-pull-request-comments/60452
|
||||
self.publish_inline_comment(comment['body'], comment['start_line'], comment['path'])
|
||||
self.publish_inline_comment(
|
||||
comment['body'], comment['start_line'], comment['path']
|
||||
)
|
||||
elif 'line' in comment: # single-line comment
|
||||
self.publish_inline_comment(comment['body'], comment['line'], comment['path'])
|
||||
self.publish_inline_comment(
|
||||
comment['body'], comment['line'], comment['path']
|
||||
)
|
||||
else:
|
||||
get_logger().error(f"Could not publish inline comment {comment}")
|
||||
|
||||
@ -450,7 +542,9 @@ class BitbucketProvider(GitProvider):
|
||||
"Bitbucket provider does not support issue comments yet"
|
||||
)
|
||||
|
||||
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
|
||||
def add_eyes_reaction(
|
||||
self, issue_comment_id: int, disable_eyes: bool = False
|
||||
) -> Optional[int]:
|
||||
return True
|
||||
|
||||
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
|
||||
@ -495,8 +589,10 @@ class BitbucketProvider(GitProvider):
|
||||
branch = self.pr.data["source"]["commit"]["hash"]
|
||||
elif branch == self.pr.destination_branch:
|
||||
branch = self.pr.data["destination"]["commit"]["hash"]
|
||||
url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
|
||||
f"{branch}/{file_path}")
|
||||
url = (
|
||||
f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
|
||||
f"{branch}/{file_path}"
|
||||
)
|
||||
response = requests.request("GET", url, headers=self.headers)
|
||||
if response.status_code == 404: # not found
|
||||
return ""
|
||||
@ -505,23 +601,28 @@ class BitbucketProvider(GitProvider):
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def create_or_update_pr_file(self, file_path: str, branch: str, contents="", message="") -> None:
|
||||
url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/")
|
||||
def create_or_update_pr_file(
|
||||
self, file_path: str, branch: str, contents="", message=""
|
||||
) -> None:
|
||||
url = f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
|
||||
if not message:
|
||||
if contents:
|
||||
message = f"Update {file_path}"
|
||||
else:
|
||||
message = f"Create {file_path}"
|
||||
files = {file_path: contents}
|
||||
data = {
|
||||
"message": message,
|
||||
"branch": branch
|
||||
}
|
||||
headers = {'Authorization': self.headers['Authorization']} if 'Authorization' in self.headers else {}
|
||||
data = {"message": message, "branch": branch}
|
||||
headers = (
|
||||
{'Authorization': self.headers['Authorization']}
|
||||
if 'Authorization' in self.headers
|
||||
else {}
|
||||
)
|
||||
try:
|
||||
requests.request("POST", url, headers=headers, data=data, files=files)
|
||||
except Exception:
|
||||
get_logger().exception(f"Failed to create empty file {file_path} in branch {branch}")
|
||||
get_logger().exception(
|
||||
f"Failed to create empty file {file_path} in branch {branch}"
|
||||
)
|
||||
|
||||
def _get_pr_file_content(self, remote_link: str):
|
||||
try:
|
||||
@ -538,16 +639,19 @@ class BitbucketProvider(GitProvider):
|
||||
|
||||
# bitbucket does not support labels
|
||||
def publish_description(self, pr_title: str, description: str):
|
||||
payload = json.dumps({
|
||||
"description": description,
|
||||
"title": pr_title
|
||||
payload = json.dumps({"description": description, "title": pr_title})
|
||||
|
||||
})
|
||||
|
||||
response = requests.request("PUT", self.bitbucket_pull_request_api_url, headers=self.headers, data=payload)
|
||||
response = requests.request(
|
||||
"PUT",
|
||||
self.bitbucket_pull_request_api_url,
|
||||
headers=self.headers,
|
||||
data=payload,
|
||||
)
|
||||
try:
|
||||
if response.status_code != 200:
|
||||
get_logger().info(f"Failed to update description, error code: {response.status_code}")
|
||||
get_logger().info(
|
||||
f"Failed to update description, error code: {response.status_code}"
|
||||
)
|
||||
except:
|
||||
pass
|
||||
return response
|
||||
|
||||
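The hunks above are pure formatting; the underlying REST call is unchanged. For readers unfamiliar with it, a minimal standalone sketch of the description update against the Bitbucket Cloud API (workspace, repo, PR number, and token are placeholders, not values from this commit):

```python
import json
import requests

# Hypothetical placeholders for illustration only.
WORKSPACE, REPO, PR_ID = "my-workspace", "my-repo", 42
BEARER_TOKEN = "..."

url = f"https://api.bitbucket.org/2.0/repositories/{WORKSPACE}/{REPO}/pullrequests/{PR_ID}"
payload = json.dumps({"description": "New description", "title": "New title"})
headers = {
    "Authorization": f"Bearer {BEARER_TOKEN}",
    "Content-Type": "application/json",
}

response = requests.request("PUT", url, headers=headers, data=payload)
if response.status_code != 200:
    print(f"Failed to update description, error code: {response.status_code}")
```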
@@ -11,8 +11,7 @@ from requests.exceptions import HTTPError
 from ..algo.git_patch_processing import decode_if_bytes
 from ..algo.language_handler import is_valid_file
 from ..algo.types import EDIT_TYPE, FilePatchInfo
-from ..algo.utils import (find_line_number_of_relevant_line_in_file,
-                          load_large_diff)
+from ..algo.utils import find_line_number_of_relevant_line_in_file, load_large_diff
 from ..config_loader import get_settings
 from ..log import get_logger
 from .git_provider import GitProvider
@@ -20,8 +19,10 @@ from .git_provider import GitProvider

 class BitbucketServerProvider(GitProvider):
     def __init__(
-            self, pr_url: Optional[str] = None, incremental: Optional[bool] = False,
-            bitbucket_client: Optional[Bitbucket] = None,
+        self,
+        pr_url: Optional[str] = None,
+        incremental: Optional[bool] = False,
+        bitbucket_client: Optional[Bitbucket] = None,
     ):
         self.bitbucket_server_url = None
         self.workspace_slug = None
@@ -36,11 +37,16 @@ class BitbucketServerProvider(GitProvider):
         self.bitbucket_pull_request_api_url = pr_url

         self.bitbucket_server_url = self._parse_bitbucket_server(url=pr_url)
-        self.bitbucket_client = bitbucket_client or Bitbucket(url=self.bitbucket_server_url,
-                                                              token=get_settings().get("BITBUCKET_SERVER.BEARER_TOKEN",
-                                                                                       None))
+        self.bitbucket_client = bitbucket_client or Bitbucket(
+            url=self.bitbucket_server_url,
+            token=get_settings().get("BITBUCKET_SERVER.BEARER_TOKEN", None),
+        )
         try:
-            self.bitbucket_api_version = parse_version(self.bitbucket_client.get("rest/api/1.0/application-properties").get('version'))
+            self.bitbucket_api_version = parse_version(
+                self.bitbucket_client.get("rest/api/1.0/application-properties").get(
+                    'version'
+                )
+            )
         except Exception:
             self.bitbucket_api_version = None

@@ -49,7 +55,12 @@ class BitbucketServerProvider(GitProvider):

     def get_repo_settings(self):
         try:
-            content = self.bitbucket_client.get_content_of_file(self.workspace_slug, self.repo_slug, ".pr_agent.toml", self.get_pr_branch())
+            content = self.bitbucket_client.get_content_of_file(
+                self.workspace_slug,
+                self.repo_slug,
+                ".pr_agent.toml",
+                self.get_pr_branch(),
+            )

             return content
         except Exception as e:
@@ -70,20 +81,27 @@ class BitbucketServerProvider(GitProvider):
         post_parameters_list = []
         for suggestion in code_suggestions:
             body = suggestion["body"]
-            original_suggestion = suggestion.get('original_suggestion', None)  # needed for diff code
+            original_suggestion = suggestion.get(
+                'original_suggestion', None
+            )  # needed for diff code
             if original_suggestion:
                 try:
                     existing_code = original_suggestion['existing_code'].rstrip() + "\n"
                     improved_code = original_suggestion['improved_code'].rstrip() + "\n"
-                    diff = difflib.unified_diff(existing_code.split('\n'),
-                                                improved_code.split('\n'), n=999)
+                    diff = difflib.unified_diff(
+                        existing_code.split('\n'), improved_code.split('\n'), n=999
+                    )
                     patch_orig = "\n".join(diff)
                     patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
                     diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
                     # replace ```suggestion ... ``` with diff_code, using regex:
-                    body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL)
+                    body = re.sub(
+                        r'```suggestion.*?```', diff_code, body, flags=re.DOTALL
+                    )
                 except Exception as e:
-                    get_logger().exception(f"Bitbucket failed to get diff code for publishing, error: {e}")
+                    get_logger().exception(
+                        f"Bitbucket failed to get diff code for publishing, error: {e}"
+                    )
                     continue
             relevant_file = suggestion["relevant_file"]
             relevant_lines_start = suggestion["relevant_lines_start"]
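The `difflib` trick in the hunk above (a unified diff with `n=999`, header lines dropped, rendered as a diff block) can be exercised standalone. A minimal sketch, with the fence built dynamically so this example does not embed literal triple backticks; the sample strings are illustrative:

```python
import difflib
import re

existing_code = "def add(a, b):\n    return a - b\n"
improved_code = "def add(a, b):\n    return a + b\n"

# n=999 keeps the whole snippet as context in a single hunk.
diff = difflib.unified_diff(existing_code.split('\n'), improved_code.split('\n'), n=999)
patch_orig = "\n".join(diff)
# The first five physical lines are the ---/+++/@@ headers plus padding; drop them.
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')

fence = "`" * 3
diff_code = f"\n\n{fence}diff\n{patch.rstrip()}\n{fence}"

# Swap a GitHub-style suggestion block for the diff rendering.
body = f"Apply this fix:\n{fence}suggestion\nreturn a + b\n{fence}"
body = re.sub(fence + r'suggestion.*?' + fence, diff_code, body, flags=re.DOTALL)
print(body)
```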
@@ -134,7 +152,12 @@ class BitbucketServerProvider(GitProvider):
             pass

     def is_supported(self, capability: str) -> bool:
-        if capability in ['get_issue_comments', 'get_labels', 'gfm_markdown', 'publish_file_comments']:
+        if capability in [
+            'get_issue_comments',
+            'get_labels',
+            'gfm_markdown',
+            'publish_file_comments',
+        ]:
             return False
         return True

@@ -145,23 +168,28 @@ class BitbucketServerProvider(GitProvider):
     def get_file(self, path: str, commit_id: str):
         file_content = ""
         try:
-            file_content = self.bitbucket_client.get_content_of_file(self.workspace_slug,
-                                                                     self.repo_slug,
-                                                                     path,
-                                                                     commit_id)
+            file_content = self.bitbucket_client.get_content_of_file(
+                self.workspace_slug, self.repo_slug, path, commit_id
+            )
         except HTTPError as e:
             get_logger().debug(f"File {path} not found at commit id: {commit_id}")
         return file_content

     def get_files(self):
-        changes = self.bitbucket_client.get_pull_requests_changes(self.workspace_slug, self.repo_slug, self.pr_num)
+        changes = self.bitbucket_client.get_pull_requests_changes(
+            self.workspace_slug, self.repo_slug, self.pr_num
+        )
         diffstat = [change["path"]['toString'] for change in changes]
         return diffstat

-    #gets the best common ancestor: https://git-scm.com/docs/git-merge-base
+    # gets the best common ancestor: https://git-scm.com/docs/git-merge-base
     @staticmethod
-    def get_best_common_ancestor(source_commits_list, destination_commits_list, guaranteed_common_ancestor) -> str:
-        destination_commit_hashes = {commit['id'] for commit in destination_commits_list} | {guaranteed_common_ancestor}
+    def get_best_common_ancestor(
+        source_commits_list, destination_commits_list, guaranteed_common_ancestor
+    ) -> str:
+        destination_commit_hashes = {
+            commit['id'] for commit in destination_commits_list
+        } | {guaranteed_common_ancestor}

         for commit in source_commits_list:
             for parent_commit in commit['parents']:
@@ -177,37 +205,55 @@ class BitbucketServerProvider(GitProvider):
         head_sha = self.pr.fromRef['latestCommit']

         # if Bitbucket api version is >= 8.16 then use the merge-base api for 2-way diff calculation
-        if self.bitbucket_api_version is not None and self.bitbucket_api_version >= parse_version("8.16"):
+        if (
+            self.bitbucket_api_version is not None
+            and self.bitbucket_api_version >= parse_version("8.16")
+        ):
             try:
                 base_sha = self.bitbucket_client.get(self._get_merge_base())['id']
             except Exception as e:
-                get_logger().error(f"Failed to get the best common ancestor for PR: {self.pr_url}, \nerror: {e}")
+                get_logger().error(
+                    f"Failed to get the best common ancestor for PR: {self.pr_url}, \nerror: {e}"
+                )
                 raise e
         else:
-            source_commits_list = list(self.bitbucket_client.get_pull_requests_commits(
-                self.workspace_slug,
-                self.repo_slug,
-                self.pr_num
-            ))
+            source_commits_list = list(
+                self.bitbucket_client.get_pull_requests_commits(
                    self.workspace_slug, self.repo_slug, self.pr_num
+                )
+            )
             # if Bitbucket api version is None or < 7.0 then do a simple diff with a guaranteed common ancestor
             base_sha = source_commits_list[-1]['parents'][0]['id']
             # if Bitbucket api version is 7.0-8.15 then use 2-way diff functionality for the base_sha
-            if self.bitbucket_api_version is not None and self.bitbucket_api_version >= parse_version("7.0"):
+            if (
+                self.bitbucket_api_version is not None
+                and self.bitbucket_api_version >= parse_version("7.0")
+            ):
                 try:
                     destination_commits = list(
-                        self.bitbucket_client.get_commits(self.workspace_slug, self.repo_slug, base_sha,
-                                                          self.pr.toRef['latestCommit']))
-                    base_sha = self.get_best_common_ancestor(source_commits_list, destination_commits, base_sha)
+                        self.bitbucket_client.get_commits(
+                            self.workspace_slug,
+                            self.repo_slug,
+                            base_sha,
+                            self.pr.toRef['latestCommit'],
+                        )
+                    )
+                    base_sha = self.get_best_common_ancestor(
+                        source_commits_list, destination_commits, base_sha
+                    )
                 except Exception as e:
                     get_logger().error(
-                        f"Failed to get the commit list for calculating best common ancestor for PR: {self.pr_url}, \nerror: {e}")
+                        f"Failed to get the commit list for calculating best common ancestor for PR: {self.pr_url}, \nerror: {e}"
+                    )
                     raise e

         diff_files = []
         original_file_content_str = ""
         new_file_content_str = ""

-        changes = self.bitbucket_client.get_pull_requests_changes(self.workspace_slug, self.repo_slug, self.pr_num)
+        changes = self.bitbucket_client.get_pull_requests_changes(
+            self.workspace_slug, self.repo_slug, self.pr_num
+        )
         for change in changes:
             file_path = change['path']['toString']
             if not is_valid_file(file_path.split("/")[-1]):
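The ancestor search above is small enough to test in isolation. A sketch with toy commit dicts shaped like Bitbucket Server's API output (`id` plus a `parents` list); note the hunk cuts off inside the inner loop, so the match-and-fallback logic shown here is an assumption consistent with the linked `git merge-base` semantics:

```python
def get_best_common_ancestor(source_commits_list, destination_commits_list,
                             guaranteed_common_ancestor) -> str:
    # Any parent of a source commit that already exists on the destination
    # branch is a common ancestor; fall back to the guaranteed one.
    destination_commit_hashes = {
        commit['id'] for commit in destination_commits_list
    } | {guaranteed_common_ancestor}
    for commit in source_commits_list:
        for parent_commit in commit['parents']:
            if parent_commit['id'] in destination_commit_hashes:
                return parent_commit['id']
    return guaranteed_common_ancestor

# Hypothetical graph: feature branch c3 -> c2 -> base, main branch c9 -> base.
source = [{'id': 'c3', 'parents': [{'id': 'c2'}]},
          {'id': 'c2', 'parents': [{'id': 'base'}]}]
destination = [{'id': 'c9', 'parents': [{'id': 'base'}]}]
print(get_best_common_ancestor(source, destination, 'base'))  # -> 'base'
```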
@@ -224,17 +270,26 @@ class BitbucketServerProvider(GitProvider):
                     edit_type = EDIT_TYPE.DELETED
                     new_file_content_str = ""
                     original_file_content_str = self.get_file(file_path, base_sha)
-                    original_file_content_str = decode_if_bytes(original_file_content_str)
+                    original_file_content_str = decode_if_bytes(
+                        original_file_content_str
+                    )
                 case 'RENAME':
                     edit_type = EDIT_TYPE.RENAMED
                 case _:
                     edit_type = EDIT_TYPE.MODIFIED
                     original_file_content_str = self.get_file(file_path, base_sha)
-                    original_file_content_str = decode_if_bytes(original_file_content_str)
+                    original_file_content_str = decode_if_bytes(
+                        original_file_content_str
+                    )
                     new_file_content_str = self.get_file(file_path, head_sha)
                     new_file_content_str = decode_if_bytes(new_file_content_str)

-            patch = load_large_diff(file_path, new_file_content_str, original_file_content_str, show_warning=False)
+            patch = load_large_diff(
+                file_path,
+                new_file_content_str,
+                original_file_content_str,
+                show_warning=False,
+            )

             diff_files.append(
                 FilePatchInfo(
@@ -251,7 +306,9 @@ class BitbucketServerProvider(GitProvider):

     def publish_comment(self, pr_comment: str, is_temporary: bool = False):
         if not is_temporary:
-            self.bitbucket_client.add_pull_request_comment(self.workspace_slug, self.repo_slug, self.pr_num, pr_comment)
+            self.bitbucket_client.add_pull_request_comment(
+                self.workspace_slug, self.repo_slug, self.pr_num, pr_comment
+            )

     def remove_initial_comment(self):
         try:
@@ -264,25 +321,37 @@ class BitbucketServerProvider(GitProvider):
             pass

     # function to create_inline_comment
-    def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
-                              absolute_position: int = None):
+    def create_inline_comment(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_line_in_file: str,
+        absolute_position: int = None,
+    ):
         position, absolute_position = find_line_number_of_relevant_line_in_file(
             self.get_diff_files(),
             relevant_file.strip('`'),
             relevant_line_in_file,
-            absolute_position
+            absolute_position,
         )
         if position == -1:
             if get_settings().config.verbosity_level >= 2:
-                get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
+                get_logger().info(
+                    f"Could not find position for {relevant_file} {relevant_line_in_file}"
+                )
             subject_type = "FILE"
         else:
             subject_type = "LINE"
         path = relevant_file.strip()
-        return dict(body=body, path=path, position=absolute_position) if subject_type == "LINE" else {}
+        return (
+            dict(body=body, path=path, position=absolute_position)
+            if subject_type == "LINE"
+            else {}
+        )

-    def publish_inline_comment(self, comment: str, from_line: int, file: str, original_suggestion=None):
+    def publish_inline_comment(
+        self, comment: str, from_line: int, file: str, original_suggestion=None
+    ):
         payload = {
             "text": comment,
             "severity": "NORMAL",
@@ -291,17 +360,24 @@ class BitbucketServerProvider(GitProvider):
                 "path": file,
                 "lineType": "ADDED",
                 "line": from_line,
-                "fileType": "TO"
-            }
+                "fileType": "TO",
+            },
         }

         try:
             self.bitbucket_client.post(self._get_pr_comments_path(), data=payload)
         except Exception as e:
-            get_logger().error(f"Failed to publish inline comment to '{file}' at line {from_line}, error: {e}")
+            get_logger().error(
+                f"Failed to publish inline comment to '{file}' at line {from_line}, error: {e}"
+            )
             raise e

-    def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
+    def get_line_link(
+        self,
+        relevant_file: str,
+        relevant_line_start: int,
+        relevant_line_end: int = None,
+    ) -> str:
         if relevant_line_start == -1:
             link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}"
         else:
@@ -316,8 +392,9 @@ class BitbucketServerProvider(GitProvider):
                 return ""

             diff_files = self.get_diff_files()
-            position, absolute_position = find_line_number_of_relevant_line_in_file \
-                (diff_files, relevant_file, relevant_line_str)
+            position, absolute_position = find_line_number_of_relevant_line_in_file(
+                diff_files, relevant_file, relevant_line_str
+            )

             if absolute_position != -1:
                 if self.pr:
@@ -325,29 +402,41 @@ class BitbucketServerProvider(GitProvider):
                     return link
                 else:
                     if get_settings().config.verbosity_level >= 2:
-                        get_logger().info(f"Failed adding line link to '{relevant_file}' since PR not set")
+                        get_logger().info(
+                            f"Failed adding line link to '{relevant_file}' since PR not set"
+                        )
             else:
                 if get_settings().config.verbosity_level >= 2:
-                    get_logger().info(f"Failed adding line link to '{relevant_file}' since position not found")
+                    get_logger().info(
+                        f"Failed adding line link to '{relevant_file}' since position not found"
+                    )

             if absolute_position != -1 and self.pr_url:
                 link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}?t={absolute_position}"
                 return link
         except Exception as e:
             if get_settings().config.verbosity_level >= 2:
-                get_logger().info(f"Failed adding line link to '{relevant_file}', error: {e}")
+                get_logger().info(
+                    f"Failed adding line link to '{relevant_file}', error: {e}"
+                )

         return ""

     def publish_inline_comments(self, comments: list[dict]):
         for comment in comments:
             if 'position' in comment:
-                self.publish_inline_comment(comment['body'], comment['position'], comment['path'])
+                self.publish_inline_comment(
+                    comment['body'], comment['position'], comment['path']
+                )
             elif 'start_line' in comment:  # multi-line comment
                 # note that bitbucket does not seem to support range - only a comment on a single line - https://community.developer.atlassian.com/t/api-post-endpoint-for-inline-pull-request-comments/60452
-                self.publish_inline_comment(comment['body'], comment['start_line'], comment['path'])
+                self.publish_inline_comment(
+                    comment['body'], comment['start_line'], comment['path']
+                )
             elif 'line' in comment:  # single-line comment
-                self.publish_inline_comment(comment['body'], comment['line'], comment['path'])
+                self.publish_inline_comment(
+                    comment['body'], comment['line'], comment['path']
+                )
             else:
                 get_logger().error(f"Could not publish inline comment: {comment}")

@@ -377,7 +466,9 @@ class BitbucketServerProvider(GitProvider):
             "Bitbucket provider does not support issue comments yet"
         )

-    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
+    def add_eyes_reaction(
+        self, issue_comment_id: int, disable_eyes: bool = False
+    ) -> Optional[int]:
         return True

     def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
@@ -411,14 +502,20 @@ class BitbucketServerProvider(GitProvider):
             users_index = -1

         if projects_index == -1 and users_index == -1:
-            raise ValueError(f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL")
+            raise ValueError(
+                f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL"
+            )

         if projects_index != -1:
             path_parts = path_parts[projects_index:]
         else:
             path_parts = path_parts[users_index:]

-        if len(path_parts) < 6 or path_parts[2] != "repos" or path_parts[4] != "pull-requests":
+        if (
+            len(path_parts) < 6
+            or path_parts[2] != "repos"
+            or path_parts[4] != "pull-requests"
+        ):
             raise ValueError(
                 f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL"
             )
@@ -430,19 +527,24 @@ class BitbucketServerProvider(GitProvider):
         try:
             pr_number = int(path_parts[5])
         except ValueError as e:
-            raise ValueError(f"Unable to convert PR number '{path_parts[5]}' to integer") from e
+            raise ValueError(
+                f"Unable to convert PR number '{path_parts[5]}' to integer"
+            ) from e

         return workspace_slug, repo_slug, pr_number

     def _get_repo(self):
         if self.repo is None:
-            self.repo = self.bitbucket_client.get_repo(self.workspace_slug, self.repo_slug)
+            self.repo = self.bitbucket_client.get_repo(
+                self.workspace_slug, self.repo_slug
+            )
         return self.repo

     def _get_pr(self):
         try:
-            pr = self.bitbucket_client.get_pull_request(self.workspace_slug, self.repo_slug,
-                                                        pull_request_id=self.pr_num)
+            pr = self.bitbucket_client.get_pull_request(
+                self.workspace_slug, self.repo_slug, pull_request_id=self.pr_num
+            )
             return type('new_dict', (object,), pr)
         except Exception as e:
             get_logger().error(f"Failed to get pull request, error: {e}")
@@ -460,10 +562,12 @@ class BitbucketServerProvider(GitProvider):
             "version": self.pr.version,
             "description": description,
             "title": pr_title,
-            "reviewers": self.pr.reviewers  # needs to be sent otherwise gets wiped
+            "reviewers": self.pr.reviewers,  # needs to be sent otherwise gets wiped
         }
         try:
-            self.bitbucket_client.update_pull_request(self.workspace_slug, self.repo_slug, str(self.pr_num), payload)
+            self.bitbucket_client.update_pull_request(
+                self.workspace_slug, self.repo_slug, str(self.pr_num), payload
+            )
         except Exception as e:
             get_logger().error(f"Failed to update pull request, error: {e}")
             raise e
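The URL-parsing checks reformatted above are easy to verify standalone. A sketch of the same validation flow as a free function (the `projects`/`users` index lookups happen just above the visible hunk, so their form here is an assumption):

```python
from urllib.parse import urlparse

def parse_bitbucket_server_pr_url(pr_url: str):
    path_parts = urlparse(pr_url).path.strip("/").split("/")
    # Server PR URLs look like .../projects/KEY/repos/slug/pull-requests/7
    # (or /users/... for personal repositories).
    projects_index = path_parts.index("projects") if "projects" in path_parts else -1
    users_index = path_parts.index("users") if "users" in path_parts else -1
    if projects_index == -1 and users_index == -1:
        raise ValueError(f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL")
    path_parts = path_parts[projects_index if projects_index != -1 else users_index:]
    if len(path_parts) < 6 or path_parts[2] != "repos" or path_parts[4] != "pull-requests":
        raise ValueError(f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL")
    workspace_slug, repo_slug = path_parts[1], path_parts[3]
    try:
        pr_number = int(path_parts[5])
    except ValueError as e:
        raise ValueError(f"Unable to convert PR number '{path_parts[5]}' to integer") from e
    return workspace_slug, repo_slug, pr_number

print(parse_bitbucket_server_pr_url(
    "https://bitbucket.example.com/projects/PROJ/repos/my-repo/pull-requests/7"))
# -> ('PROJ', 'my-repo', 7)
```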
@@ -31,7 +31,9 @@ class CodeCommitPullRequestResponse:

         self.targets = []
         for target in json.get("pullRequestTargets", []):
-            self.targets.append(CodeCommitPullRequestResponse.CodeCommitPullRequestTarget(target))
+            self.targets.append(
+                CodeCommitPullRequestResponse.CodeCommitPullRequestTarget(target)
+            )

     class CodeCommitPullRequestTarget:
         """
@@ -65,7 +67,9 @@ class CodeCommitClient:
         except Exception as e:
             raise ValueError(f"Failed to connect to AWS CodeCommit: {e}") from e

-    def get_differences(self, repo_name: int, destination_commit: str, source_commit: str):
+    def get_differences(
+        self, repo_name: int, destination_commit: str, source_commit: str
+    ):
         """
         Get the differences between two commits in CodeCommit.

@@ -96,17 +100,25 @@ class CodeCommitClient:
                 differences.extend(page.get("differences", []))
         except botocore.exceptions.ClientError as e:
             if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
-                raise ValueError(f"CodeCommit cannot retrieve differences: Repository does not exist: {repo_name}") from e
-            raise ValueError(f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}") from e
+                raise ValueError(
+                    f"CodeCommit cannot retrieve differences: Repository does not exist: {repo_name}"
+                ) from e
+            raise ValueError(
+                f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}"
+            ) from e
         except Exception as e:
-            raise ValueError(f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}") from e
+            raise ValueError(
+                f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}"
+            ) from e

         output = []
         for json in differences:
             output.append(CodeCommitDifferencesResponse(json))
         return output

-    def get_file(self, repo_name: str, file_path: str, sha_hash: str, optional: bool = False):
+    def get_file(
+        self, repo_name: str, file_path: str, sha_hash: str, optional: bool = False
+    ):
         """
         Retrieve a file from CodeCommit.

@@ -129,16 +141,24 @@ class CodeCommitClient:
             self._connect_boto_client()

         try:
-            response = self.boto_client.get_file(repositoryName=repo_name, commitSpecifier=sha_hash, filePath=file_path)
+            response = self.boto_client.get_file(
+                repositoryName=repo_name, commitSpecifier=sha_hash, filePath=file_path
+            )
         except botocore.exceptions.ClientError as e:
             if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
-                raise ValueError(f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}") from e
+                raise ValueError(
+                    f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}"
+                ) from e
             # if the file does not exist, but is flagged as optional, then return an empty string
             if optional and e.response["Error"]["Code"] == 'FileDoesNotExistException':
                 return ""
-            raise ValueError(f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'") from e
+            raise ValueError(
+                f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'"
+            ) from e
         except Exception as e:
-            raise ValueError(f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'") from e
+            raise ValueError(
+                f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'"
+            ) from e
         if "fileContent" not in response:
             raise ValueError(f"File content is empty for file: {file_path}")

@@ -166,10 +186,16 @@ class CodeCommitClient:
             response = self.boto_client.get_pull_request(pullRequestId=str(pr_number))
         except botocore.exceptions.ClientError as e:
             if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
-                raise ValueError(f"CodeCommit cannot retrieve PR: PR number does not exist: {pr_number}") from e
+                raise ValueError(
+                    f"CodeCommit cannot retrieve PR: PR number does not exist: {pr_number}"
+                ) from e
             if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
-                raise ValueError(f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}") from e
-            raise ValueError(f"CodeCommit cannot retrieve PR: {pr_number}: boto client error") from e
+                raise ValueError(
+                    f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}"
+                ) from e
+            raise ValueError(
+                f"CodeCommit cannot retrieve PR: {pr_number}: boto client error"
+            ) from e
         except Exception as e:
             raise ValueError(f"CodeCommit cannot retrieve PR: {pr_number}") from e

@@ -200,22 +226,37 @@ class CodeCommitClient:
             self._connect_boto_client()

         try:
-            self.boto_client.update_pull_request_title(pullRequestId=str(pr_number), title=pr_title)
-            self.boto_client.update_pull_request_description(pullRequestId=str(pr_number), description=pr_body)
+            self.boto_client.update_pull_request_title(
+                pullRequestId=str(pr_number), title=pr_title
+            )
+            self.boto_client.update_pull_request_description(
+                pullRequestId=str(pr_number), description=pr_body
+            )
         except botocore.exceptions.ClientError as e:
             if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
                 raise ValueError(f"PR number does not exist: {pr_number}") from e
             if e.response["Error"]["Code"] == 'InvalidTitleException':
                 raise ValueError(f"Invalid title for PR number: {pr_number}") from e
             if e.response["Error"]["Code"] == 'InvalidDescriptionException':
-                raise ValueError(f"Invalid description for PR number: {pr_number}") from e
+                raise ValueError(
+                    f"Invalid description for PR number: {pr_number}"
+                ) from e
             if e.response["Error"]["Code"] == 'PullRequestAlreadyClosedException':
                 raise ValueError(f"PR is already closed: PR number: {pr_number}") from e
             raise ValueError(f"Boto3 client error calling publish_description") from e
         except Exception as e:
             raise ValueError(f"Error calling publish_description") from e

-    def publish_comment(self, repo_name: str, pr_number: int, destination_commit: str, source_commit: str, comment: str, annotation_file: str = None, annotation_line: int = None):
+    def publish_comment(
+        self,
+        repo_name: str,
+        pr_number: int,
+        destination_commit: str,
+        source_commit: str,
+        comment: str,
+        annotation_file: str = None,
+        annotation_line: int = None,
+    ):
         """
         Publish a comment to a pull request

@@ -272,6 +313,8 @@ class CodeCommitClient:
                 raise ValueError(f"Repository does not exist: {repo_name}") from e
             if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
                 raise ValueError(f"PR number does not exist: {pr_number}") from e
-            raise ValueError(f"Boto3 client error calling post_comment_for_pull_request") from e
+            raise ValueError(
+                f"Boto3 client error calling post_comment_for_pull_request"
+            ) from e
         except Exception as e:
             raise ValueError(f"Error calling post_comment_for_pull_request") from e
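For context, the calls being wrapped above are the standard boto3 CodeCommit client API. A minimal sketch of the file fetch with the optional-file fallback, outside the class (region and names are placeholders):

```python
import boto3
import botocore.exceptions

boto_client = boto3.client("codecommit", region_name="us-east-1")  # placeholder region

def get_file(repo_name: str, file_path: str, sha_hash: str, optional: bool = False):
    try:
        response = boto_client.get_file(
            repositoryName=repo_name, commitSpecifier=sha_hash, filePath=file_path
        )
    except botocore.exceptions.ClientError as e:
        # An optional file (e.g. a settings file) may simply not exist.
        if optional and e.response["Error"]["Code"] == "FileDoesNotExistException":
            return ""
        raise ValueError(
            f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'"
        ) from e
    return response["fileContent"]
```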
@@ -55,7 +55,9 @@ class CodeCommitProvider(GitProvider):
     This class implements the GitProvider interface for AWS CodeCommit repositories.
     """

-    def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
+    def __init__(
+        self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
+    ):
         self.codecommit_client = CodeCommitClient()
         self.aws_client = None
         self.repo_name = None
@@ -76,7 +78,7 @@ class CodeCommitProvider(GitProvider):
             "create_inline_comment",
             "publish_inline_comments",
             "get_labels",
-            "gfm_markdown"
+            "gfm_markdown",
         ]:
             return False
         return True
@@ -91,13 +93,19 @@ class CodeCommitProvider(GitProvider):
             return self.git_files

         self.git_files = []
-        differences = self.codecommit_client.get_differences(self.repo_name, self.pr.destination_commit, self.pr.source_commit)
+        differences = self.codecommit_client.get_differences(
+            self.repo_name, self.pr.destination_commit, self.pr.source_commit
+        )
         for item in differences:
-            self.git_files.append(CodeCommitFile(item.before_blob_path,
-                                                 item.before_blob_id,
-                                                 item.after_blob_path,
-                                                 item.after_blob_id,
-                                                 CodeCommitProvider._get_edit_type(item.change_type)))
+            self.git_files.append(
+                CodeCommitFile(
+                    item.before_blob_path,
+                    item.before_blob_id,
+                    item.after_blob_path,
+                    item.after_blob_id,
+                    CodeCommitProvider._get_edit_type(item.change_type),
+                )
+            )
         return self.git_files

     def get_diff_files(self) -> list[FilePatchInfo]:
@@ -121,21 +129,28 @@ class CodeCommitProvider(GitProvider):
             if diff_item.a_blob_id is not None:
                 patch_filename = diff_item.a_path
                 original_file_content_str = self.codecommit_client.get_file(
-                    self.repo_name, diff_item.a_path, self.pr.destination_commit)
+                    self.repo_name, diff_item.a_path, self.pr.destination_commit
+                )
                 if isinstance(original_file_content_str, (bytes, bytearray)):
-                    original_file_content_str = original_file_content_str.decode("utf-8")
+                    original_file_content_str = original_file_content_str.decode(
+                        "utf-8"
+                    )
             else:
                 original_file_content_str = ""

             if diff_item.b_blob_id is not None:
                 patch_filename = diff_item.b_path
-                new_file_content_str = self.codecommit_client.get_file(self.repo_name, diff_item.b_path, self.pr.source_commit)
+                new_file_content_str = self.codecommit_client.get_file(
+                    self.repo_name, diff_item.b_path, self.pr.source_commit
+                )
                 if isinstance(new_file_content_str, (bytes, bytearray)):
                     new_file_content_str = new_file_content_str.decode("utf-8")
             else:
                 new_file_content_str = ""

-            patch = load_large_diff(patch_filename, new_file_content_str, original_file_content_str)
+            patch = load_large_diff(
+                patch_filename, new_file_content_str, original_file_content_str
+            )

             # Store the diffs as a list of FilePatchInfo objects
             info = FilePatchInfo(
@@ -164,7 +179,9 @@ class CodeCommitProvider(GitProvider):
                 pr_body=CodeCommitProvider._add_additional_newlines(pr_body),
             )
         except Exception as e:
-            raise ValueError(f"CodeCommit Cannot publish description for PR: {self.pr_num}") from e
+            raise ValueError(
+                f"CodeCommit Cannot publish description for PR: {self.pr_num}"
+            ) from e

     def publish_comment(self, pr_comment: str, is_temporary: bool = False):
         if is_temporary:
@@ -183,19 +200,28 @@ class CodeCommitProvider(GitProvider):
                 comment=pr_comment,
             )
         except Exception as e:
-            raise ValueError(f"CodeCommit Cannot publish comment for PR: {self.pr_num}") from e
+            raise ValueError(
+                f"CodeCommit Cannot publish comment for PR: {self.pr_num}"
+            ) from e

     def publish_code_suggestions(self, code_suggestions: list) -> bool:
         counter = 1
         for suggestion in code_suggestions:
             # Verify that each suggestion has the required keys
-            if not all(key in suggestion for key in ["body", "relevant_file", "relevant_lines_start"]):
-                get_logger().warning(f"Skipping code suggestion #{counter}: Each suggestion must have 'body', 'relevant_file', 'relevant_lines_start' keys")
+            if not all(
+                key in suggestion
+                for key in ["body", "relevant_file", "relevant_lines_start"]
+            ):
+                get_logger().warning(
+                    f"Skipping code suggestion #{counter}: Each suggestion must have 'body', 'relevant_file', 'relevant_lines_start' keys"
+                )
                 continue

             # Publish the code suggestion to CodeCommit
             try:
-                get_logger().debug(f"Code Suggestion #{counter} in file: {suggestion['relevant_file']}: {suggestion['relevant_lines_start']}")
+                get_logger().debug(
+                    f"Code Suggestion #{counter} in file: {suggestion['relevant_file']}: {suggestion['relevant_lines_start']}"
+                )
                 self.codecommit_client.publish_comment(
                     repo_name=self.repo_name,
                     pr_number=self.pr_num,
@@ -206,7 +232,9 @@ class CodeCommitProvider(GitProvider):
                     annotation_line=suggestion["relevant_lines_start"],
                 )
             except Exception as e:
-                raise ValueError(f"CodeCommit Cannot publish code suggestions for PR: {self.pr_num}") from e
+                raise ValueError(
+                    f"CodeCommit Cannot publish code suggestions for PR: {self.pr_num}"
+                ) from e

             counter += 1

@@ -227,12 +255,22 @@ class CodeCommitProvider(GitProvider):
     def remove_comment(self, comment):
         return ""  # not implemented yet

-    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
+    def publish_inline_comment(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_line_in_file: str,
+        original_suggestion=None,
+    ):
         # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/post_comment_for_compared_commit.html
-        raise NotImplementedError("CodeCommit provider does not support publishing inline comments yet")
+        raise NotImplementedError(
+            "CodeCommit provider does not support publishing inline comments yet"
+        )

     def publish_inline_comments(self, comments: list[dict]):
-        raise NotImplementedError("CodeCommit provider does not support publishing inline comments yet")
+        raise NotImplementedError(
+            "CodeCommit provider does not support publishing inline comments yet"
+        )

     def get_title(self):
         return self.pr.title
@@ -257,7 +295,7 @@ class CodeCommitProvider(GitProvider):
         - dict: A dictionary where each key is a language name and the corresponding value is the percentage of that language in the PR.
         """
         commit_files = self.get_files()
-        filenames = [ item.filename for item in commit_files ]
+        filenames = [item.filename for item in commit_files]
         extensions = CodeCommitProvider._get_file_extensions(filenames)

         # Calculate the percentage of each file extension in the PR
@@ -270,7 +308,9 @@ class CodeCommitProvider(GitProvider):
         # We build that language->extension dictionary here in main_extensions_flat.
         main_extensions_flat = {}
         language_extension_map_org = get_settings().language_extension_map_org
-        language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
+        language_extension_map = {
+            k.lower(): v for k, v in language_extension_map_org.items()
+        }
         for language, extensions in language_extension_map.items():
             for ext in extensions:
                 main_extensions_flat[ext] = language
@@ -292,14 +332,20 @@ class CodeCommitProvider(GitProvider):
         return -1  # not implemented yet

     def get_issue_comments(self):
-        raise NotImplementedError("CodeCommit provider does not support issue comments yet")
+        raise NotImplementedError(
+            "CodeCommit provider does not support issue comments yet"
+        )

     def get_repo_settings(self):
         # a local ".pr_agent.toml" settings file is optional
         settings_filename = ".pr_agent.toml"
-        return self.codecommit_client.get_file(self.repo_name, settings_filename, self.pr.source_commit, optional=True)
+        return self.codecommit_client.get_file(
+            self.repo_name, settings_filename, self.pr.source_commit, optional=True
+        )

-    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
+    def add_eyes_reaction(
+        self, issue_comment_id: int, disable_eyes: bool = False
+    ) -> Optional[int]:
         get_logger().info("CodeCommit provider does not support eyes reaction yet")
         return True

@@ -323,7 +369,9 @@ class CodeCommitProvider(GitProvider):
         parsed_url = urlparse(pr_url)

         if not CodeCommitProvider._is_valid_codecommit_hostname(parsed_url.netloc):
-            raise ValueError(f"The provided URL is not a valid CodeCommit URL: {pr_url}")
+            raise ValueError(
+                f"The provided URL is not a valid CodeCommit URL: {pr_url}"
+            )

         path_parts = parsed_url.path.strip("/").split("/")

@@ -334,14 +382,18 @@ class CodeCommitProvider(GitProvider):
             or path_parts[2] != "repositories"
             or path_parts[4] != "pull-requests"
         ):
-            raise ValueError(f"The provided URL does not appear to be a CodeCommit PR URL: {pr_url}")
+            raise ValueError(
+                f"The provided URL does not appear to be a CodeCommit PR URL: {pr_url}"
+            )

         repo_name = path_parts[3]

         try:
             pr_number = int(path_parts[5])
         except ValueError as e:
-            raise ValueError(f"Unable to convert PR number to integer: '{path_parts[5]}'") from e
+            raise ValueError(
+                f"Unable to convert PR number to integer: '{path_parts[5]}'"
+            ) from e

         return repo_name, pr_number

@@ -359,7 +411,12 @@ class CodeCommitProvider(GitProvider):
         Returns:
         - bool: True if the hostname is valid, False otherwise.
         """
-        return re.match(r"^[a-z]{2}-(gov-)?[a-z]+-\d\.console\.aws\.amazon\.com$", hostname) is not None
+        return (
+            re.match(
+                r"^[a-z]{2}-(gov-)?[a-z]+-\d\.console\.aws\.amazon\.com$", hostname
+            )
+            is not None
+        )

     def _get_pr(self):
         response = self.codecommit_client.get_pr(self.repo_name, self.pr_num)
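The hostname check reformatted above is a useful pattern on its own; a quick demonstration of the same regex against valid and invalid CodeCommit console hosts:

```python
import re

def is_valid_codecommit_hostname(hostname: str) -> bool:
    # Matches e.g. us-east-1.console.aws.amazon.com and GovCloud variants.
    return (
        re.match(
            r"^[a-z]{2}-(gov-)?[a-z]+-\d\.console\.aws\.amazon\.com$", hostname
        )
        is not None
    )

print(is_valid_codecommit_hostname("us-east-1.console.aws.amazon.com"))      # True
print(is_valid_codecommit_hostname("us-gov-west-1.console.aws.amazon.com"))  # True
print(is_valid_codecommit_hostname("evil.example.com"))                      # False
```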
@@ -38,10 +38,7 @@ def clone(url, directory):

 def fetch(url, refspec, cwd):
     get_logger().info("Fetching %s %s", url, refspec)
-    stdout = _call(
-        'git', 'fetch', '--depth', '2', url, refspec,
-        cwd=cwd
-    )
+    stdout = _call('git', 'fetch', '--depth', '2', url, refspec, cwd=cwd)
     get_logger().info(stdout)


@@ -75,10 +72,13 @@ def add_comment(url: urllib3.util.Url, refspec, message):
     message = "'" + message.replace("'", "'\"'\"'") + "'"
     return _call(
         "ssh",
-        "-p", str(url.port),
+        "-p",
+        str(url.port),
         f"{url.auth}@{url.host}",
-        "gerrit", "review",
-        "--message", message,
+        "gerrit",
+        "review",
+        "--message",
+        message,
         # "--code-review", score,
         f"{patchset},{changenum}",
     )
@@ -88,19 +88,23 @@ def list_comments(url: urllib3.util.Url, refspec):
     *_, patchset, _ = refspec.rsplit("/")
     stdout = _call(
         "ssh",
-        "-p", str(url.port),
+        "-p",
+        str(url.port),
         f"{url.auth}@{url.host}",
-        "gerrit", "query",
+        "gerrit",
+        "query",
         "--comments",
-        "--current-patch-set", patchset,
-        "--format", "JSON",
+        "--current-patch-set",
+        patchset,
+        "--format",
+        "JSON",
     )
     change_set, *_ = stdout.splitlines()
     return json.loads(change_set)["currentPatchSet"]["comments"]


 def prepare_repo(url: urllib3.util.Url, project, refspec):
-    repo_url = (f"{url.scheme}://{url.auth}@{url.host}:{url.port}/{project}")
+    repo_url = f"{url.scheme}://{url.auth}@{url.host}:{url.port}/{project}"

     directory = pathlib.Path(mkdtemp())
     clone(repo_url, directory),
@@ -114,18 +118,18 @@ def adopt_to_gerrit_message(message):
     buf = []
     for line in lines:
         # remove markdown formatting
-        line = (line.replace("*", "")
-                .replace("``", "`")
-                .replace("<details>", "")
-                .replace("</details>", "")
-                .replace("<summary>", "")
-                .replace("</summary>", ""))
+        line = (
+            line.replace("*", "")
+            .replace("``", "`")
+            .replace("<details>", "")
+            .replace("</details>", "")
+            .replace("<summary>", "")
+            .replace("</summary>", "")
+        )

         line = line.strip()
         if line.startswith('#'):
-            buf.append("\n" +
-                       line.replace('#', '').removesuffix(":").strip() +
-                       ":")
+            buf.append("\n" + line.replace('#', '').removesuffix(":").strip() + ":")
             continue
         elif line.startswith('-'):
             buf.append(line.removeprefix('-').strip())
@@ -136,12 +140,9 @@ def adopt_to_gerrit_message(message):


 def add_suggestion(src_filename, context: str, start, end: int):
-    with (
-        NamedTemporaryFile("w", delete=False) as tmp,
-        open(src_filename, "r") as src
-    ):
+    with NamedTemporaryFile("w", delete=False) as tmp, open(src_filename, "r") as src:
         lines = src.readlines()
-        tmp.writelines(lines[:start - 1])
+        tmp.writelines(lines[: start - 1])
         if context:
             tmp.write(context)
         tmp.writelines(lines[end:])
@@ -151,10 +152,8 @@ def add_suggestion(src_filename, context: str, start, end: int):


 def upload_patch(patch, path):
-    patch_server_endpoint = get_settings().get(
-        'gerrit.patch_server_endpoint')
-    patch_server_token = get_settings().get(
-        'gerrit.patch_server_token')
+    patch_server_endpoint = get_settings().get('gerrit.patch_server_endpoint')
+    patch_server_token = get_settings().get('gerrit.patch_server_token')

     response = requests.post(
         patch_server_endpoint,
@@ -165,7 +164,7 @@ def upload_patch(patch, path):
         headers={
             "Content-Type": "application/json",
             "Authorization": f"Bearer {patch_server_token}",
-        }
+        },
     )
     response.raise_for_status()
     patch_server_endpoint = patch_server_endpoint.rstrip("/")
@@ -173,7 +172,6 @@ def upload_patch(patch, path):


 class GerritProvider(GitProvider):
-
     def __init__(self, key: str, incremental=False):
         self.project, self.refspec = key.split(':')
         assert self.project, "Project name is required"
@@ -188,9 +186,7 @@ class GerritProvider(GitProvider):
             f"{parsed.scheme}://{user}@{parsed.host}:{parsed.port}"
         )

-        self.repo_path = prepare_repo(
-            self.parsed_url, self.project, self.refspec
-        )
+        self.repo_path = prepare_repo(self.parsed_url, self.project, self.refspec)
         self.repo = Repo(self.repo_path)
         assert self.repo
         self.pr_url = base_url
@@ -210,15 +206,18 @@ class GerritProvider(GitProvider):

     def get_pr_labels(self, update=False):
         raise NotImplementedError(
-            'Getting labels is not implemented for the gerrit provider')
+            'Getting labels is not implemented for the gerrit provider'
+        )

     def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False):
         raise NotImplementedError(
-            'Adding reactions is not implemented for the gerrit provider')
+            'Adding reactions is not implemented for the gerrit provider'
+        )

     def remove_reaction(self, issue_comment_id: int, reaction_id: int):
         raise NotImplementedError(
-            'Removing reactions is not implemented for the gerrit provider')
+            'Removing reactions is not implemented for the gerrit provider'
+        )

     def get_commit_messages(self):
         return [self.repo.head.commit.message]
@@ -235,20 +234,21 @@ class GerritProvider(GitProvider):
         diffs = self.repo.head.commit.diff(
             self.repo.head.commit.parents[0],  # previous commit
             create_patch=True,
-            R=True
+            R=True,
         )

         diff_files = []
         for diff_item in diffs:
             if diff_item.a_blob is not None:
-                original_file_content_str = (
-                    diff_item.a_blob.data_stream.read().decode('utf-8')
+                original_file_content_str = diff_item.a_blob.data_stream.read().decode(
+                    'utf-8'
                 )
             else:
                 original_file_content_str = ""  # empty file
             if diff_item.b_blob is not None:
-                new_file_content_str = diff_item.b_blob.data_stream.read(). \
-                    decode('utf-8')
+                new_file_content_str = diff_item.b_blob.data_stream.read().decode(
+                    'utf-8'
+                )
             else:
                 new_file_content_str = ""  # empty file
             edit_type = EDIT_TYPE.MODIFIED
@@ -267,7 +267,7 @@ class GerritProvider(GitProvider):
                     edit_type=edit_type,
                     old_filename=None
                     if diff_item.a_path == diff_item.b_path
-                    else diff_item.a_path
+                    else diff_item.a_path,
                 )
             )
         self.diff_files = diff_files
@@ -275,8 +275,7 @@ class GerritProvider(GitProvider):

     def get_files(self):
         diff_index = self.repo.head.commit.diff(
-            self.repo.head.commit.parents[0],  # previous commit
-            R=True
+            self.repo.head.commit.parents[0], R=True  # previous commit
         )
         # Get the list of changed files
         diff_files = [item.a_path for item in diff_index]
@@ -288,16 +287,22 @@ class GerritProvider(GitProvider):
        prioritisation.
         """
         # Get all files in repository
-        filepaths = [Path(item.path) for item in
-                     self.repo.tree().traverse() if item.type == 'blob']
+        filepaths = [
+            Path(item.path)
+            for item in self.repo.tree().traverse()
+            if item.type == 'blob'
+        ]
         # Identify language by file extension and count
         lang_count = Counter(
-            ext.lstrip('.') for filepath in filepaths for ext in
-            [filepath.suffix.lower()])
+            ext.lstrip('.')
+            for filepath in filepaths
+            for ext in [filepath.suffix.lower()]
+        )
         # Convert counts to percentages
         total_files = len(filepaths)
-        lang_percentage = {lang: count / total_files * 100 for lang, count
-                           in lang_count.items()}
+        lang_percentage = {
+            lang: count / total_files * 100 for lang, count in lang_count.items()
+        }
         return lang_percentage

     def get_pr_description_full(self):
@@ -312,7 +317,7 @@ class GerritProvider(GitProvider):
             'create_inline_comment',
             'publish_inline_comments',
             'get_labels',
-            'gfm_markdown'
+            'gfm_markdown',
         ]:
             return False
         return True
@@ -331,14 +336,9 @@ class GerritProvider(GitProvider):
             if is_code_context:
                 context.append(line)
             else:
-                description.append(
-                    line.replace('*', '')
-                )
+                description.append(line.replace('*', ''))

-        return (
-            '\n'.join(description),
-            '\n'.join(context) + '\n' if context else ''
-        )
+        return ('\n'.join(description), '\n'.join(context) + '\n' if context else '')

     def publish_code_suggestions(self, code_suggestions: list):
         msg = []
@@ -372,15 +372,19 @@ class GerritProvider(GitProvider):

     def publish_inline_comments(self, comments: list[dict]):
         raise NotImplementedError(
-            'Publishing inline comments is not implemented for the gerrit '
-            'provider')
+            'Publishing inline comments is not implemented for the gerrit ' 'provider'
+        )

-    def publish_inline_comment(self, body: str, relevant_file: str,
-                               relevant_line_in_file: str, original_suggestion=None):
+    def publish_inline_comment(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_line_in_file: str,
+        original_suggestion=None,
+    ):
         raise NotImplementedError(
-            'Publishing inline comments is not implemented for the gerrit '
-            'provider')
+            'Publishing inline comments is not implemented for the gerrit ' 'provider'
+        )

     def publish_labels(self, labels):
         # Not applicable to the local git provider,
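The markdown-stripping routine reformatted above is easy to run standalone. A sketch of `adopt_to_gerrit_message` with the surrounding lines that the hunks omit filled in as assumptions (the final join and the plain-line fallback are not shown in the diff):

```python
def adopt_to_gerrit_message(message: str) -> str:
    lines = message.splitlines()
    buf = []
    for line in lines:
        # remove markdown formatting
        line = (
            line.replace("*", "")
            .replace("``", "`")
            .replace("<details>", "")
            .replace("</details>", "")
            .replace("<summary>", "")
            .replace("</summary>", "")
        )
        line = line.strip()
        if line.startswith('#'):
            # headings become "Title:" paragraphs
            buf.append("\n" + line.replace('#', '').removesuffix(":").strip() + ":")
            continue
        elif line.startswith('-'):
            buf.append(line.removeprefix('-').strip())
            continue
        buf.append(line)
    return "\n".join(buf).strip()

print(adopt_to_gerrit_message("## Summary:\n- *Fixes* a bug\n<details>extra</details>"))
# Summary:
# Fixes a bug
# extra
```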
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+
 # enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
 from typing import Optional

@@ -9,6 +10,7 @@ from utils.pr_agent.log import get_logger

 MAX_FILES_ALLOWED_FULL = 50

+
 class GitProvider(ABC):
     @abstractmethod
     def is_supported(self, capability: str) -> bool:
@@ -61,11 +63,18 @@ class GitProvider(ABC):
     def reply_to_comment_from_comment_id(self, comment_id: int, body: str):
         pass

-    def get_pr_description(self, full: bool = True, split_changes_walkthrough=False) -> str or tuple:
+    def get_pr_description(
+        self, full: bool = True, split_changes_walkthrough=False
+    ) -> str or tuple:
         from utils.pr_agent.algo.utils import clip_tokens
         from utils.pr_agent.config_loader import get_settings
-        max_tokens_description = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
-        description = self.get_pr_description_full() if full else self.get_user_description()
+
+        max_tokens_description = get_settings().get(
+            "CONFIG.MAX_DESCRIPTION_TOKENS", None
+        )
+        description = (
+            self.get_pr_description_full() if full else self.get_user_description()
+        )
         if split_changes_walkthrough:
             description, files = process_description(description)
             if max_tokens_description:
@@ -94,7 +103,9 @@ class GitProvider(ABC):
         # return nothing (empty string) because it means there is no user description
         user_description_header = "### **user description**"
         if user_description_header not in description_lowercase:
-            get_logger().info(f"Existing description was generated by the pr-agent, but it doesn't contain a user description")
+            get_logger().info(
+                f"Existing description was generated by the pr-agent, but it doesn't contain a user description"
+            )
             return ""

         # otherwise, extract the original user description from the existing pr-agent description and return it
@@ -103,9 +114,11 @@ class GitProvider(ABC):

         # the 'user description' is in the beginning. extract and return it
         possible_headers = self._possible_headers()
-        start_position = description_lowercase.find(user_description_header) + len(user_description_header)
+        start_position = description_lowercase.find(user_description_header) + len(
+            user_description_header
+        )
         end_position = len(description)
-        for header in possible_headers: # try to clip at the next header
+        for header in possible_headers:  # try to clip at the next header
             if header != user_description_header and header in description_lowercase:
                 end_position = min(end_position, description_lowercase.find(header))
         if end_position != len(description) and end_position > start_position:
@@ -115,20 +128,34 @@ class GitProvider(ABC):
         else:
             original_user_description = description.split("___")[0].strip()
             if original_user_description.lower().startswith(user_description_header):
-                original_user_description = original_user_description[len(user_description_header):].strip()
+                original_user_description = original_user_description[
+                    len(user_description_header) :
+                ].strip()

-        get_logger().info(f"Extracted user description from existing description",
-                          description=original_user_description)
+        get_logger().info(
+            f"Extracted user description from existing description",
+            description=original_user_description,
+        )
         self.user_description = original_user_description
         return original_user_description

     def _possible_headers(self):
-        return ("### **user description**", "### **pr type**", "### **pr description**", "### **pr labels**", "### **type**", "### **description**",
-                "### **labels**", "### 🤖 generated by pr agent")
+        return (
+            "### **user description**",
+            "### **pr type**",
+            "### **pr description**",
+            "### **pr labels**",
+            "### **type**",
+            "### **description**",
+            "### **labels**",
+            "### 🤖 generated by pr agent",
+        )

     def _is_generated_by_pr_agent(self, description_lowercase: str) -> bool:
         possible_headers = self._possible_headers()
-        return any(description_lowercase.startswith(header) for header in possible_headers)
+        return any(
+            description_lowercase.startswith(header) for header in possible_headers
+        )

     @abstractmethod
     def get_repo_settings(self):
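The header-clipping logic above can be exercised standalone. A sketch with a toy PR body; the header constants come straight from `_possible_headers`, trimmed to three for brevity:

```python
def extract_user_description(description: str) -> str:
    description_lowercase = description.lower()
    user_description_header = "### **user description**"
    possible_headers = ("### **user description**", "### **pr type**",
                        "### **pr description**")
    if user_description_header not in description_lowercase:
        return ""
    start_position = description_lowercase.find(user_description_header) + len(
        user_description_header
    )
    end_position = len(description)
    for header in possible_headers:  # try to clip at the next header
        if header != user_description_header and header in description_lowercase:
            end_position = min(end_position, description_lowercase.find(header))
    return description[start_position:end_position].strip()

pr_body = "### **User description**\nPlease review carefully.\n\n### **PR type**\nBugfix"
print(extract_user_description(pr_body))  # -> 'Please review carefully.'
```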
@ -140,10 +167,17 @@ class GitProvider(ABC):
|
||||
def get_pr_id(self):
|
||||
return ""
|
||||
|
||||
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
|
||||
def get_line_link(
|
||||
self,
|
||||
relevant_file: str,
|
||||
relevant_line_start: int,
|
||||
relevant_line_end: int = None,
|
||||
) -> str:
|
||||
return ""
|
||||
|
||||
def get_lines_link_original_file(self, filepath:str, component_range: Range) -> str:
|
||||
def get_lines_link_original_file(
|
||||
self, filepath: str, component_range: Range
|
||||
) -> str:
|
||||
return ""
|
||||
|
||||
#### comments operations ####
|
||||
@ -151,18 +185,24 @@ class GitProvider(ABC):
|
||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||
pass
|
||||
|
||||
def publish_persistent_comment(self, pr_comment: str,
|
||||
initial_header: str,
|
||||
update_header: bool = True,
|
||||
name='review',
|
||||
final_update_message=True):
|
||||
def publish_persistent_comment(
|
||||
self,
|
||||
pr_comment: str,
|
||||
initial_header: str,
|
||||
update_header: bool = True,
|
||||
name='review',
|
||||
final_update_message=True,
|
||||
):
|
||||
self.publish_comment(pr_comment)
|
||||
|
||||
def publish_persistent_comment_full(self, pr_comment: str,
|
||||
initial_header: str,
|
||||
update_header: bool = True,
|
||||
name='review',
|
||||
final_update_message=True):
|
||||
def publish_persistent_comment_full(
|
||||
self,
|
||||
pr_comment: str,
|
||||
initial_header: str,
|
||||
update_header: bool = True,
|
||||
name='review',
|
||||
final_update_message=True,
|
||||
):
|
||||
try:
|
||||
prev_comments = list(self.get_issue_comments())
|
||||
for comment in prev_comments:
|
||||
@ -171,29 +211,46 @@ class GitProvider(ABC):
|
||||
comment_url = self.get_comment_url(comment)
|
||||
if update_header:
|
||||
updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n"
|
||||
pr_comment_updated = pr_comment.replace(initial_header, updated_header)
|
||||
pr_comment_updated = pr_comment.replace(
|
||||
initial_header, updated_header
|
||||
)
|
||||
else:
|
||||
pr_comment_updated = pr_comment
|
||||
get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message")
|
||||
get_logger().info(
|
||||
f"Persistent mode - updating comment {comment_url} to latest {name} message"
|
||||
)
|
||||
# response = self.mr.notes.update(comment.id, {'body': pr_comment_updated})
|
||||
self.edit_comment(comment, pr_comment_updated)
|
||||
if final_update_message:
|
||||
self.publish_comment(
|
||||
f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}")
|
||||
f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}"
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
get_logger().exception(f"Failed to update persistent review, error: {e}")
|
||||
pass
|
||||
self.publish_comment(pr_comment)
|
||||
|
||||
|
||||
    @abstractmethod
    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
    def publish_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        original_suggestion=None,
    ):
        pass

    def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
                              absolute_position: int = None):
        raise NotImplementedError("This git provider does not support creating inline comments yet")
    def create_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        absolute_position: int = None,
    ):
        raise NotImplementedError(
            "This git provider does not support creating inline comments yet"
        )

    @abstractmethod
    def publish_inline_comments(self, comments: list[dict]):
@@ -227,7 +284,9 @@ class GitProvider(ABC):
        pass

    @abstractmethod
    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
    def add_eyes_reaction(
        self, issue_comment_id: int, disable_eyes: bool = False
    ) -> Optional[int]:
        pass

    @abstractmethod
@@ -284,16 +343,23 @@ def get_main_pr_language(languages, files) -> str:
            if not file:
                continue
            if isinstance(file, str):
                file = FilePatchInfo(base_file=None, head_file=None, patch=None, filename=file)
                file = FilePatchInfo(
                    base_file=None, head_file=None, patch=None, filename=file
                )
            extension_list.append(file.filename.rsplit('.')[-1])

        # get the most common extension
        most_common_extension = '.' + max(set(extension_list), key=extension_list.count)
        try:
            language_extension_map_org = get_settings().language_extension_map_org
            language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
            language_extension_map = {
                k.lower(): v for k, v in language_extension_map_org.items()
            }

            if top_language in language_extension_map and most_common_extension in language_extension_map[top_language]:
            if (
                top_language in language_extension_map
                and most_common_extension in language_extension_map[top_language]
            ):
                main_language_str = top_language
            else:
                for language, extensions in language_extension_map.items():
@@ -332,8 +398,6 @@ def get_main_pr_language(languages, files) -> str:
    return main_language_str


class IncrementalPR:
    def __init__(self, is_incremental: bool = False):
        self.is_incremental = is_incremental

@@ -18,14 +18,23 @@ from ..algo.file_filter import filter_ignored
from ..algo.git_patch_processing import extract_hunk_headers
from ..algo.language_handler import is_valid_file
from ..algo.types import EDIT_TYPE
from ..algo.utils import (PRReviewHeader, Range, clip_tokens,
                          find_line_number_of_relevant_line_in_file,
                          load_large_diff, set_file_languages)
from ..algo.utils import (
    PRReviewHeader,
    Range,
    clip_tokens,
    find_line_number_of_relevant_line_in_file,
    load_large_diff,
    set_file_languages,
)
from ..config_loader import get_settings
from ..log import get_logger
from ..servers.utils import RateLimitExceeded
from .git_provider import (MAX_FILES_ALLOWED_FULL, FilePatchInfo, GitProvider,
                           IncrementalPR)
from .git_provider import (
    MAX_FILES_ALLOWED_FULL,
    FilePatchInfo,
    GitProvider,
    IncrementalPR,
)


class GithubProvider(GitProvider):
@@ -36,8 +45,14 @@ class GithubProvider(GitProvider):
        except Exception:
            self.installation_id = None
        self.max_comment_chars = 65000
        self.base_url = get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/") # "https://api.github.com"
        self.base_url_html = self.base_url.split("api/")[0].rstrip("/") if "api/" in self.base_url else "https://github.com"
        self.base_url = (
            get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/")
        )  # "https://api.github.com"
        self.base_url_html = (
            self.base_url.split("api/")[0].rstrip("/")
            if "api/" in self.base_url
            else "https://github.com"
        )
        self.github_client = self._get_github_client()
        self.repo = None
        self.pr_num = None
@@ -50,7 +65,9 @@ class GithubProvider(GitProvider):
            self.set_pr(pr_url)
            self.pr_commits = list(self.pr.get_commits())
            self.last_commit_id = self.pr_commits[-1]
            self.pr_url = self.get_pr_url() # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object
            self.pr_url = (
                self.get_pr_url()
            )  # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object
        else:
            self.pr_commits = None

@@ -80,10 +97,14 @@ class GithubProvider(GitProvider):
            # Get all files changed during the commit range

            for commit in self.incremental.commits_range:
                if commit.commit.message.startswith(f"Merge branch '{self._get_repo().default_branch}'"):
                if commit.commit.message.startswith(
                    f"Merge branch '{self._get_repo().default_branch}'"
                ):
                    get_logger().info(f"Skipping merge commit {commit.commit.message}")
                    continue
                self.unreviewed_files_set.update({file.filename: file for file in commit.files})
                self.unreviewed_files_set.update(
                    {file.filename: file for file in commit.files}
                )
        else:
            get_logger().info("No previous review found, will review the entire PR")
            self.incremental.is_incremental = False
@@ -98,7 +119,11 @@ class GithubProvider(GitProvider):
            else:
                self.incremental.last_seen_commit = self.pr_commits[index]
                break
        return self.pr_commits[first_new_commit_index:] if first_new_commit_index is not None else []
        return (
            self.pr_commits[first_new_commit_index:]
            if first_new_commit_index is not None
            else []
        )

    def get_previous_review(self, *, full: bool, incremental: bool):
        if not (full or incremental):
@@ -121,7 +146,7 @@ class GithubProvider(GitProvider):
            git_files = context.get("git_files", None)
            if git_files:
                return git_files
            self.git_files = list(self.pr.get_files()) # 'list' to handle pagination
            self.git_files = list(self.pr.get_files())  # 'list' to handle pagination
            context["git_files"] = self.git_files
            return self.git_files
        except Exception:
@@ -138,8 +163,13 @@ class GithubProvider(GitProvider):
        except Exception as e:
            return -1

    @retry(exceptions=RateLimitExceeded,
           tries=get_settings().github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3))
    @retry(
        exceptions=RateLimitExceeded,
        tries=get_settings().github.ratelimit_retries,
        delay=2,
        backoff=2,
        jitter=(1, 3),
    )
    def get_diff_files(self) -> list[FilePatchInfo]:
        """
        Retrieves the list of files that have been modified, added, deleted, or renamed in a pull request in GitHub,
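The `@retry` decorator above (from the `retry` package) re-runs `get_diff_files` when `RateLimitExceeded` is raised, scaling the delay by `backoff` after each attempt and adding random jitter. A rough sketch of the wait schedule it produces; the exact accumulation rules live inside the library, so treat this as an approximation:

```python
# Approximate wait schedule for retry(tries=N, delay=2, backoff=2, jitter=(1, 3)):
# each failed attempt sleeps the current delay, then the delay is scaled by
# `backoff` and nudged by a random jitter drawn from the (1, 3) range.
import random

def sketch_schedule(tries: int, delay: float = 2, backoff: float = 2,
                    jitter=(1, 3)) -> list[float]:
    waits = []
    current = delay
    for _ in range(tries - 1):  # no sleep after the final attempt
        waits.append(current)
        current = current * backoff + random.uniform(*jitter)
    return waits

print(sketch_schedule(4))  # e.g. [2, ~5.8, ~13.9]
```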
@@ -167,9 +197,10 @@ class GithubProvider(GitProvider):
        try:
            names_original = [file.filename for file in files_original]
            names_new = [file.filename for file in files]
            get_logger().info(f"Filtered out [ignore] files for pull request:", extra=
                              {"files": names_original,
                               "filtered_files": names_new})
            get_logger().info(
                f"Filtered out [ignore] files for pull request:",
                extra={"files": names_original, "filtered_files": names_new},
            )
        except Exception:
            pass

@@ -184,14 +215,17 @@ class GithubProvider(GitProvider):
        repo = self.repo_obj
        pr = self.pr
        try:
            compare = repo.compare(pr.base.sha, pr.head.sha) # communication with GitHub
            compare = repo.compare(
                pr.base.sha, pr.head.sha
            )  # communication with GitHub
            merge_base_commit = compare.merge_base_commit
        except Exception as e:
            get_logger().error(f"Failed to get merge base commit: {e}")
            merge_base_commit = pr.base
        if merge_base_commit.sha != pr.base.sha:
            get_logger().info(
                f"Using merge base commit {merge_base_commit.sha} instead of base commit ")
                f"Using merge base commit {merge_base_commit.sha} instead of base commit "
            )

        counter_valid = 0
        for file in files:
@@ -207,29 +241,48 @@ class GithubProvider(GitProvider):
            # allow only a limited number of files to be fully loaded. We can manage the rest with diffs only
            counter_valid += 1
            avoid_load = False
            if counter_valid >= MAX_FILES_ALLOWED_FULL and patch and not self.incremental.is_incremental:
            if (
                counter_valid >= MAX_FILES_ALLOWED_FULL
                and patch
                and not self.incremental.is_incremental
            ):
                avoid_load = True
                if counter_valid == MAX_FILES_ALLOWED_FULL:
                    get_logger().info(f"Too many files in PR, will avoid loading full content for rest of files")
                    get_logger().info(
                        f"Too many files in PR, will avoid loading full content for rest of files"
                    )

            if avoid_load:
                new_file_content_str = ""
            else:
                new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) # communication with GitHub
                new_file_content_str = self._get_pr_file_content(
                    file, self.pr.head.sha
                )  # communication with GitHub

            if self.incremental.is_incremental and self.unreviewed_files_set:
                original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha)
                patch = load_large_diff(file.filename, new_file_content_str, original_file_content_str)
                original_file_content_str = self._get_pr_file_content(
                    file, self.incremental.last_seen_commit_sha
                )
                patch = load_large_diff(
                    file.filename,
                    new_file_content_str,
                    original_file_content_str,
                )
                self.unreviewed_files_set[file.filename] = patch
            else:
                if avoid_load:
                    original_file_content_str = ""
                else:
                    original_file_content_str = self._get_pr_file_content(file, merge_base_commit.sha)
                    original_file_content_str = self._get_pr_file_content(
                        file, merge_base_commit.sha
                    )
                # original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
                if not patch:
                    patch = load_large_diff(file.filename, new_file_content_str, original_file_content_str)

                    patch = load_large_diff(
                        file.filename,
                        new_file_content_str,
                        original_file_content_str,
                    )

            if file.status == 'added':
                edit_type = EDIT_TYPE.ADDED
@@ -249,16 +302,27 @@ class GithubProvider(GitProvider):
                num_minus_lines = file.deletions
            else:
                patch_lines = patch.splitlines(keepends=True)
                num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
                num_minus_lines = len([line for line in patch_lines if line.startswith('-')])
                num_plus_lines = len(
                    [line for line in patch_lines if line.startswith('+')]
                )
                num_minus_lines = len(
                    [line for line in patch_lines if line.startswith('-')]
                )

            file_patch_canonical_structure = FilePatchInfo(original_file_content_str, new_file_content_str, patch,
                                                           file.filename, edit_type=edit_type,
                                                           num_plus_lines=num_plus_lines,
                                                           num_minus_lines=num_minus_lines,)
            file_patch_canonical_structure = FilePatchInfo(
                original_file_content_str,
                new_file_content_str,
                patch,
                file.filename,
                edit_type=edit_type,
                num_plus_lines=num_plus_lines,
                num_minus_lines=num_minus_lines,
            )
            diff_files.append(file_patch_canonical_structure)
        if invalid_files_names:
            get_logger().info(f"Filtered out files with invalid extensions: {invalid_files_names}")
            get_logger().info(
                f"Filtered out files with invalid extensions: {invalid_files_names}"
            )

        self.diff_files = diff_files
        try:
@@ -269,8 +333,10 @@ class GithubProvider(GitProvider):
            return diff_files

        except Exception as e:
            get_logger().error(f"Failing to get diff files: {e}",
                               artifact={"traceback": traceback.format_exc()})
            get_logger().error(
                f"Failing to get diff files: {e}",
                artifact={"traceback": traceback.format_exc()},
            )
            raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e

    def publish_description(self, pr_title: str, pr_body: str):
@@ -282,16 +348,23 @@ class GithubProvider(GitProvider):
    def get_comment_url(self, comment) -> str:
        return comment.html_url

    def publish_persistent_comment(self, pr_comment: str,
                                   initial_header: str,
                                   update_header: bool = True,
                                   name='review',
                                   final_update_message=True):
        self.publish_persistent_comment_full(pr_comment, initial_header, update_header, name, final_update_message)
    def publish_persistent_comment(
        self,
        pr_comment: str,
        initial_header: str,
        update_header: bool = True,
        name='review',
        final_update_message=True,
    ):
        self.publish_persistent_comment_full(
            pr_comment, initial_header, update_header, name, final_update_message
        )

    def publish_comment(self, pr_comment: str, is_temporary: bool = False):
        if is_temporary and not get_settings().config.publish_output_progress:
            get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
            get_logger().debug(
                f"Skipping publish_comment for temporary comment: {pr_comment}"
            )
            return None
        pr_comment = self.limit_output_characters(pr_comment, self.max_comment_chars)
        response = self.pr.create_issue_comment(pr_comment)
@@ -303,42 +376,68 @@ class GithubProvider(GitProvider):
            self.pr.comments_list.append(response)
        return response

    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
    def publish_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        original_suggestion=None,
    ):
        body = self.limit_output_characters(body, self.max_comment_chars)
        self.publish_inline_comments([self.create_inline_comment(body, relevant_file, relevant_line_in_file)])
        self.publish_inline_comments(
            [self.create_inline_comment(body, relevant_file, relevant_line_in_file)]
        )

    def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
                              absolute_position: int = None):
    def create_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        absolute_position: int = None,
    ):
        body = self.limit_output_characters(body, self.max_comment_chars)
        position, absolute_position = find_line_number_of_relevant_line_in_file(self.diff_files,
                                                                                relevant_file.strip('`'),
                                                                                relevant_line_in_file,
                                                                                absolute_position)
        position, absolute_position = find_line_number_of_relevant_line_in_file(
            self.diff_files,
            relevant_file.strip('`'),
            relevant_line_in_file,
            absolute_position,
        )
        if position == -1:
            get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
            get_logger().info(
                f"Could not find position for {relevant_file} {relevant_line_in_file}"
            )
            subject_type = "FILE"
        else:
            subject_type = "LINE"
        path = relevant_file.strip()
        return dict(body=body, path=path, position=position) if subject_type == "LINE" else {}
        return (
            dict(body=body, path=path, position=position)
            if subject_type == "LINE"
            else {}
        )

    def publish_inline_comments(self, comments: list[dict], disable_fallback: bool = False):
    def publish_inline_comments(
        self, comments: list[dict], disable_fallback: bool = False
    ):
        try:
            # publish all comments in a single message
            self.pr.create_review(commit=self.last_commit_id, comments=comments)
        except Exception as e:
            get_logger().info(f"Initially failed to publish inline comments as committable")
            get_logger().info(
                f"Initially failed to publish inline comments as committable"
            )

            if (getattr(e, "status", None) == 422 and not disable_fallback):
            if getattr(e, "status", None) == 422 and not disable_fallback:
                pass # continue to try _publish_inline_comments_fallback_with_verification
            else:
                raise e # will end up with publishing the comments one by one
                raise e  # will end up with publishing the comments one by one

            try:
                self._publish_inline_comments_fallback_with_verification(comments)
            except Exception as e:
                get_logger().error(f"Failed to publish inline code comments fallback, error: {e}")
                get_logger().error(
                    f"Failed to publish inline code comments fallback, error: {e}"
                )
                raise e

    def _publish_inline_comments_fallback_with_verification(self, comments: list[dict]):
@@ -352,20 +451,27 @@ class GithubProvider(GitProvider):
        # publish as a group the verified comments
        if verified_comments:
            try:
                self.pr.create_review(commit=self.last_commit_id, comments=verified_comments)
                self.pr.create_review(
                    commit=self.last_commit_id, comments=verified_comments
                )
            except:
                pass

        # try to publish one by one the invalid comments as a one-line code comment
        if invalid_comments and get_settings().github.try_fix_invalid_inline_comments:
            fixed_comments_as_one_liner = self._try_fix_invalid_inline_comments(
                [comment for comment, _ in invalid_comments])
                [comment for comment, _ in invalid_comments]
            )
            for comment in fixed_comments_as_one_liner:
                try:
                    self.publish_inline_comments([comment], disable_fallback=True)
                    get_logger().info(f"Published invalid comment as a single line comment: {comment}")
                    get_logger().info(
                        f"Published invalid comment as a single line comment: {comment}"
                    )
                except:
                    get_logger().error(f"Failed to publish invalid comment as a single line comment: {comment}")
                    get_logger().error(
                        f"Failed to publish invalid comment as a single line comment: {comment}"
                    )

    def _verify_code_comment(self, comment: dict):
        is_verified = False
@@ -374,7 +480,8 @@ class GithubProvider(GitProvider):
            # event ="" # By leaving this blank, you set the review action state to PENDING
            input = dict(commit_id=self.last_commit_id.sha, comments=[comment])
            headers, data = self.pr._requester.requestJsonAndCheck(
                "POST", f"{self.pr.url}/reviews", input=input)
                "POST", f"{self.pr.url}/reviews", input=input
            )
            pending_review_id = data["id"]
            is_verified = True
        except Exception as err:
@@ -383,12 +490,16 @@ class GithubProvider(GitProvider):
            e = err
        if pending_review_id is not None:
            try:
                self.pr._requester.requestJsonAndCheck("DELETE", f"{self.pr.url}/reviews/{pending_review_id}")
                self.pr._requester.requestJsonAndCheck(
                    "DELETE", f"{self.pr.url}/reviews/{pending_review_id}"
                )
            except Exception:
                pass
        return is_verified, e

    def _verify_code_comments(self, comments: list[dict]) -> tuple[list[dict], list[tuple[dict, Exception]]]:
    def _verify_code_comments(
        self, comments: list[dict]
    ) -> tuple[list[dict], list[tuple[dict, Exception]]]:
        """Verify each comment against the GitHub API and return 2 lists: 1 of verified and 1 of invalid comments"""
        verified_comments = []
        invalid_comments = []
@@ -401,17 +512,22 @@ class GithubProvider(GitProvider):
                invalid_comments.append((comment, e))
        return verified_comments, invalid_comments

    def _try_fix_invalid_inline_comments(self, invalid_comments: list[dict]) -> list[dict]:
    def _try_fix_invalid_inline_comments(
        self, invalid_comments: list[dict]
    ) -> list[dict]:
        """
        Try fixing invalid comments by removing the suggestion part and setting the comment just on the first line.
        Return only comments that have been modified in some way.
        This is a best-effort attempt to fix invalid comments, and should be verified accordingly.
        """
        import copy

        fixed_comments = []
        for comment in invalid_comments:
            try:
                fixed_comment = copy.deepcopy(comment) # avoid modifying the original comment dict for later logging
                fixed_comment = copy.deepcopy(
                    comment
                )  # avoid modifying the original comment dict for later logging
                if "```suggestion" in comment["body"]:
                    fixed_comment["body"] = comment["body"].split("```suggestion")[0]
                if "start_line" in comment:
@@ -432,7 +548,9 @@ class GithubProvider(GitProvider):
        """
        post_parameters_list = []

        code_suggestions_validated = self.validate_comments_inside_hunks(code_suggestions)
        code_suggestions_validated = self.validate_comments_inside_hunks(
            code_suggestions
        )

        for suggestion in code_suggestions_validated:
            body = suggestion['body']
@@ -442,13 +560,16 @@ class GithubProvider(GitProvider):

            if not relevant_lines_start or relevant_lines_start == -1:
                get_logger().exception(
                    f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
                    f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}"
                )
                continue

            if relevant_lines_end < relevant_lines_start:
                get_logger().exception(f"Failed to publish code suggestion, "
                                       f"relevant_lines_end is {relevant_lines_end} and "
                                       f"relevant_lines_start is {relevant_lines_start}")
                get_logger().exception(
                    f"Failed to publish code suggestion, "
                    f"relevant_lines_end is {relevant_lines_end} and "
                    f"relevant_lines_start is {relevant_lines_start}"
                )
                continue

            if relevant_lines_end > relevant_lines_start:
@@ -484,17 +605,21 @@ class GithubProvider(GitProvider):
                # Log as warning for permission-related issues (usually due to polling)
                get_logger().warning(
                    "Failed to edit github comment due to permission restrictions",
                    artifact={"error": e})
                    artifact={"error": e},
                )
            else:
                get_logger().exception(f"Failed to edit github comment", artifact={"error": e})
                get_logger().exception(
                    f"Failed to edit github comment", artifact={"error": e}
                )

    def edit_comment_from_comment_id(self, comment_id: int, body: str):
        try:
            # self.pr.get_issue_comment(comment_id).edit(body)
            body = self.limit_output_characters(body, self.max_comment_chars)
            headers, data_patch = self.pr._requester.requestJsonAndCheck(
                "PATCH", f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}",
                input={"body": body}
                "PATCH",
                f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}",
                input={"body": body},
            )
        except Exception as e:
            get_logger().exception(f"Failed to edit comment, error: {e}")
@@ -504,8 +629,9 @@ class GithubProvider(GitProvider):
            # self.pr.get_issue_comment(comment_id).edit(body)
            body = self.limit_output_characters(body, self.max_comment_chars)
            headers, data_patch = self.pr._requester.requestJsonAndCheck(
                "POST", f"{self.base_url}/repos/{self.repo}/pulls/{self.pr_num}/comments/{comment_id}/replies",
                input={"body": body}
                "POST",
                f"{self.base_url}/repos/{self.repo}/pulls/{self.pr_num}/comments/{comment_id}/replies",
                input={"body": body},
            )
        except Exception as e:
            get_logger().exception(f"Failed to reply comment, error: {e}")
@@ -516,7 +642,7 @@ class GithubProvider(GitProvider):
            headers, data_patch = self.pr._requester.requestJsonAndCheck(
                "GET", f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}"
            )
            return data_patch.get("body","")
            return data_patch.get("body", "")
        except Exception as e:
            get_logger().exception(f"Failed to edit comment, error: {e}")
            return None
@@ -528,7 +654,9 @@ class GithubProvider(GitProvider):
        )
        for comment in file_comments:
            comment['commit_id'] = self.last_commit_id.sha
            comment['body'] = self.limit_output_characters(comment['body'], self.max_comment_chars)
            comment['body'] = self.limit_output_characters(
                comment['body'], self.max_comment_chars
            )

            found = False
            for existing_comment in existing_comments:
@@ -536,13 +664,23 @@ class GithubProvider(GitProvider):
                our_app_name = get_settings().get("GITHUB.APP_NAME", "")
                same_comment_creator = False
                if self.deployment_type == 'app':
                    same_comment_creator = our_app_name.lower() in existing_comment['user']['login'].lower()
                    same_comment_creator = (
                        our_app_name.lower()
                        in existing_comment['user']['login'].lower()
                    )
                elif self.deployment_type == 'user':
                    same_comment_creator = self.github_user_id == existing_comment['user']['login']
                if existing_comment['subject_type'] == 'file' and comment['path'] == existing_comment['path'] and same_comment_creator:

                    same_comment_creator = (
                        self.github_user_id == existing_comment['user']['login']
                    )
                if (
                    existing_comment['subject_type'] == 'file'
                    and comment['path'] == existing_comment['path']
                    and same_comment_creator
                ):
                    headers, data_patch = self.pr._requester.requestJsonAndCheck(
                        "PATCH", f"{self.base_url}/repos/{self.repo}/pulls/comments/{existing_comment['id']}", input={"body":comment['body']}
                        "PATCH",
                        f"{self.base_url}/repos/{self.repo}/pulls/comments/{existing_comment['id']}",
                        input={"body": comment['body']},
                    )
                    found = True
                    break
@@ -600,7 +738,9 @@ class GithubProvider(GitProvider):
        deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user")

        if deployment_type != 'user':
            raise ValueError("Deployment mode must be set to 'user' to get notifications")
            raise ValueError(
                "Deployment mode must be set to 'user' to get notifications"
            )

        notifications = self.github_client.get_user().get_notifications(since=since)
        return notifications
@@ -621,13 +761,16 @@ class GithubProvider(GitProvider):
    def get_workspace_name(self):
        return self.repo.split('/')[0]

    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
    def add_eyes_reaction(
        self, issue_comment_id: int, disable_eyes: bool = False
    ) -> Optional[int]:
        if disable_eyes:
            return None
        try:
            headers, data_patch = self.pr._requester.requestJsonAndCheck(
                "POST", f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions",
                input={"content": "eyes"}
                "POST",
                f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions",
                input={"content": "eyes"},
            )
            return data_patch.get("id", None)
        except Exception as e:
@@ -639,7 +782,7 @@ class GithubProvider(GitProvider):
            # self.pr.get_issue_comment(issue_comment_id).delete_reaction(reaction_id)
            headers, data_patch = self.pr._requester.requestJsonAndCheck(
                "DELETE",
                f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions/{reaction_id}"
                f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions/{reaction_id}",
            )
            return True
        except Exception as e:
@@ -655,7 +798,9 @@ class GithubProvider(GitProvider):
        path_parts = parsed_url.path.strip('/').split('/')
        if 'api.github.com' in parsed_url.netloc or '/api/v3' in pr_url:
            if len(path_parts) < 5 or path_parts[3] != 'pulls':
                raise ValueError("The provided URL does not appear to be a GitHub PR URL")
                raise ValueError(
                    "The provided URL does not appear to be a GitHub PR URL"
                )
            repo_name = '/'.join(path_parts[1:3])
            try:
                pr_number = int(path_parts[4])
@@ -683,7 +828,9 @@ class GithubProvider(GitProvider):
        path_parts = parsed_url.path.strip('/').split('/')
        if 'api.github.com' in parsed_url.netloc:
            if len(path_parts) < 5 or path_parts[3] != 'issues':
                raise ValueError("The provided URL does not appear to be a GitHub ISSUE URL")
                raise ValueError(
                    "The provided URL does not appear to be a GitHub ISSUE URL"
                )
            repo_name = '/'.join(path_parts[1:3])
            try:
                issue_number = int(path_parts[4])
@@ -710,11 +857,18 @@ class GithubProvider(GitProvider):
                private_key = get_settings().github.private_key
                app_id = get_settings().github.app_id
            except AttributeError as e:
                raise ValueError("GitHub app ID and private key are required when using GitHub app deployment") from e
                raise ValueError(
                    "GitHub app ID and private key are required when using GitHub app deployment"
                ) from e
            if not self.installation_id:
                raise ValueError("GitHub app installation ID is required when using GitHub app deployment")
            auth = AppAuthentication(app_id=app_id, private_key=private_key,
                                     installation_id=self.installation_id)
                raise ValueError(
                    "GitHub app installation ID is required when using GitHub app deployment"
                )
            auth = AppAuthentication(
                app_id=app_id,
                private_key=private_key,
                installation_id=self.installation_id,
            )
            return Github(app_auth=auth, base_url=self.base_url)

        if deployment_type == 'user':
@@ -723,19 +877,21 @@ class GithubProvider(GitProvider):
            except AttributeError as e:
                raise ValueError(
                    "GitHub token is required when using user deployment. See: "
                    "https://github.com/Codium-ai/pr-agent#method-2-run-from-source") from e
                    "https://github.com/Codium-ai/pr-agent#method-2-run-from-source"
                ) from e
            return Github(auth=Auth.Token(token), base_url=self.base_url)

    def _get_repo(self):
        if hasattr(self, 'repo_obj') and \
                hasattr(self.repo_obj, 'full_name') and \
                self.repo_obj.full_name == self.repo:
        if (
            hasattr(self, 'repo_obj')
            and hasattr(self.repo_obj, 'full_name')
            and self.repo_obj.full_name == self.repo
        ):
            return self.repo_obj
        else:
            self.repo_obj = self.github_client.get_repo(self.repo)
            return self.repo_obj

    def _get_pr(self):
        return self._get_repo().get_pull(self.pr_num)

@@ -755,9 +911,9 @@ class GithubProvider(GitProvider):
    ) -> None:
        try:
            file_obj = self._get_repo().get_contents(file_path, ref=branch)
            sha1=file_obj.sha
            sha1 = file_obj.sha
        except Exception:
            sha1=""
            sha1 = ""
        self.repo_obj.update_file(
            path=file_path,
            message=message,
@@ -771,9 +927,14 @@ class GithubProvider(GitProvider):

    def publish_labels(self, pr_types):
        try:
            label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5",
                               "Enhancement": "bfd4f2", "Documentation": "d4c5f9",
                               "Other": "d1bcf9"}
            label_color_map = {
                "Bug fix": "1d76db",
                "Tests": "e99695",
                "Bug fix with tests": "c5def5",
                "Enhancement": "bfd4f2",
                "Documentation": "d4c5f9",
                "Other": "d1bcf9",
            }
            post_parameters = []
            for p in pr_types:
                color = label_color_map.get(p, "d1bcf9") # default to "Other" color
@@ -787,11 +948,12 @@ class GithubProvider(GitProvider):
    def get_pr_labels(self, update=False):
        try:
            if not update:
                labels =self.pr.labels
                labels = self.pr.labels
                return [label.name for label in labels]
            else: # obtain the latest labels. Maybe they changed while the AI was running
            else:  # obtain the latest labels. Maybe they changed while the AI was running
                headers, labels = self.pr._requester.requestJsonAndCheck(
                    "GET", f"{self.pr.issue_url}/labels")
                    "GET", f"{self.pr.issue_url}/labels"
                )
                return [label['name'] for label in labels]

        except Exception as e:
@@ -813,7 +975,9 @@ class GithubProvider(GitProvider):
        try:
            commit_list = self.pr.get_commits()
            commit_messages = [commit.commit.message for commit in commit_list]
            commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages)])
            commit_messages_str = "\n".join(
                [f"{i + 1}. {message}" for i, message in enumerate(commit_messages)]
            )
        except Exception:
            commit_messages_str = ""
        if max_tokens:
@@ -822,13 +986,16 @@ class GithubProvider(GitProvider):

    def generate_link_to_relevant_line_number(self, suggestion) -> str:
        try:
            relevant_file = suggestion['relevant_file'].strip('`').strip("'").strip('\n')
            relevant_file = (
                suggestion['relevant_file'].strip('`').strip("'").strip('\n')
            )
            relevant_line_str = suggestion['relevant_line'].strip('\n')
            if not relevant_line_str:
                return ""

            position, absolute_position = find_line_number_of_relevant_line_in_file \
                (self.diff_files, relevant_file, relevant_line_str)
            position, absolute_position = find_line_number_of_relevant_line_in_file(
                self.diff_files, relevant_file, relevant_line_str
            )

            if absolute_position != -1:
                # # link to right file only
@@ -844,7 +1011,12 @@ class GithubProvider(GitProvider):

        return ""

    def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
    def get_line_link(
        self,
        relevant_file: str,
        relevant_line_start: int,
        relevant_line_end: int = None,
    ) -> str:
        sha_file = hashlib.sha256(relevant_file.encode('utf-8')).hexdigest()
        if relevant_line_start == -1:
            link = f"{self.base_url_html}/{self.repo}/pull/{self.pr_num}/files#diff-{sha_file}"
@@ -854,7 +1026,9 @@ class GithubProvider(GitProvider):
            link = f"{self.base_url_html}/{self.repo}/pull/{self.pr_num}/files#diff-{sha_file}R{relevant_line_start}"
        return link

    def get_lines_link_original_file(self, filepath: str, component_range: Range) -> str:
    def get_lines_link_original_file(
        self, filepath: str, component_range: Range
    ) -> str:
        """
        Returns the link to the original file on GitHub that corresponds to the given filepath and component range.

@@ -876,8 +1050,10 @@ class GithubProvider(GitProvider):
        line_end = component_range.line_end + 1
        # link = (f"https://github.com/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/"
        #         f"#L{line_start}-L{line_end}")
        link = (f"{self.base_url_html}/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/"
                f"#L{line_start}-L{line_end}")
        link = (
            f"{self.base_url_html}/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/"
            f"#L{line_start}-L{line_end}"
        )

        return link

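`get_line_link` above builds GitHub's "Files changed" anchors: the URL fragment is `diff-` plus the SHA-256 of the file path, and `R<line>` selects a line on the right (new) side of the diff. A standalone sketch of the same scheme; the repo name and PR number below are made-up examples, not values from this codebase:

```python
# Minimal sketch of the GitHub PR diff-anchor scheme used by get_line_link.
import hashlib

def line_link(base_url_html: str, repo: str, pr_num: int, path: str, line: int) -> str:
    # anchor id = sha256 of the file path; R<n> points at line n of the new file
    sha_file = hashlib.sha256(path.encode('utf-8')).hexdigest()
    return f"{base_url_html}/{repo}/pull/{pr_num}/files#diff-{sha_file}R{line}"

print(line_link("https://github.com", "octocat/hello-world", 42, "src/app.py", 7))
```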
@@ -909,8 +1085,9 @@ class GithubProvider(GitProvider):
            }}
        }}
        """
        response_tuple = self.github_client._Github__requester.requestJson("POST", "/graphql",
                                                                           input={"query": query})
        response_tuple = self.github_client._Github__requester.requestJson(
            "POST", "/graphql", input={"query": query}
        )

        # Extract the JSON response from the tuple and parses it
        if isinstance(response_tuple, tuple) and len(response_tuple) == 3:
@@ -919,8 +1096,12 @@ class GithubProvider(GitProvider):
            get_logger().error(f"Unexpected response format: {response_tuple}")
            return sub_issues

        issue_id = response_json.get("data", {}).get("repository", {}).get("issue", {}).get("id")
        issue_id = (
            response_json.get("data", {})
            .get("repository", {})
            .get("issue", {})
            .get("id")
        )

        if not issue_id:
            get_logger().warning(f"Issue ID not found for {issue_url}")
@@ -940,22 +1121,42 @@ class GithubProvider(GitProvider):
            }}
        }}
        """
        sub_issues_response_tuple = self.github_client._Github__requester.requestJson("POST", "/graphql", input={
            "query": sub_issues_query})
        sub_issues_response_tuple = (
            self.github_client._Github__requester.requestJson(
                "POST", "/graphql", input={"query": sub_issues_query}
            )
        )

        # Extract the JSON response from the tuple and parses it
        if isinstance(sub_issues_response_tuple, tuple) and len(sub_issues_response_tuple) == 3:
        if (
            isinstance(sub_issues_response_tuple, tuple)
            and len(sub_issues_response_tuple) == 3
        ):
            sub_issues_response_json = json.loads(sub_issues_response_tuple[2])
        else:
            get_logger().error("Unexpected sub-issues response format", artifact={"response": sub_issues_response_tuple})
            get_logger().error(
                "Unexpected sub-issues response format",
                artifact={"response": sub_issues_response_tuple},
            )
            return sub_issues

        if not sub_issues_response_json.get("data", {}).get("node", {}).get("subIssues"):
        if (
            not sub_issues_response_json.get("data", {})
            .get("node", {})
            .get("subIssues")
        ):
            get_logger().error("Invalid sub-issues response structure")
            return sub_issues

        nodes = sub_issues_response_json.get("data", {}).get("node", {}).get("subIssues", {}).get("nodes", [])
        get_logger().info(f"Github Sub-issues fetched: {len(nodes)}", artifact={"nodes": nodes})

        nodes = (
            sub_issues_response_json.get("data", {})
            .get("node", {})
            .get("subIssues", {})
            .get("nodes", [])
        )
        get_logger().info(
            f"Github Sub-issues fetched: {len(nodes)}", artifact={"nodes": nodes}
        )

        for sub_issue in nodes:
            if "url" in sub_issue:
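The chained `.get(..., {})` calls above are a defensive way to walk a nested GraphQL response: every missing level yields an empty dict, so the chain ends in `None` instead of raising `KeyError`. A small illustration against a hypothetical payload:

```python
# Hypothetical GraphQL payload; any missing level degrades gracefully to None.
response_json = {"data": {"repository": {"issue": {"id": "I_abc123"}}}}

issue_id = (
    response_json.get("data", {})
    .get("repository", {})
    .get("issue", {})
    .get("id")
)
print(issue_id)  # -> I_abc123; with an empty payload this prints None instead of raising
```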
@@ -977,7 +1178,7 @@ class GithubProvider(GitProvider):
        return False

    def calc_pr_statistics(self, pull_request_data: dict):
        return {}

    def validate_comments_inside_hunks(self, code_suggestions):
        """
@@ -986,7 +1187,8 @@ class GithubProvider(GitProvider):
        code_suggestions_copy = copy.deepcopy(code_suggestions)
        diff_files = self.get_diff_files()
        RE_HUNK_HEADER = re.compile(
            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)"
        )

        diff_files = set_file_languages(diff_files)

@@ -995,7 +1197,6 @@ class GithubProvider(GitProvider):
            relevant_file_path = suggestion['relevant_file']
            for file in diff_files:
                if file.filename == relevant_file_path:

                    # generate on-demand the patches range for the relevant file
                    patch_str = file.patch
                    if not hasattr(file, 'patches_range'):
@@ -1006,14 +1207,30 @@ class GithubProvider(GitProvider):
                            match = RE_HUNK_HEADER.match(line)
                            # identify hunk header
                            if match:
                                section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
                                file.patches_range.append({'start': start2, 'end': start2 + size2 - 1})
                                (
                                    section_header,
                                    size1,
                                    size2,
                                    start1,
                                    start2,
                                ) = extract_hunk_headers(match)
                                file.patches_range.append(
                                    {'start': start2, 'end': start2 + size2 - 1}
                                )

                    patches_range = file.patches_range
                    comment_start_line = suggestion.get('relevant_lines_start', None)
                    comment_start_line = suggestion.get(
                        'relevant_lines_start', None
                    )
                    comment_end_line = suggestion.get('relevant_lines_end', None)
                    original_suggestion = suggestion.get('original_suggestion', None) # needed for diff code
                    if not comment_start_line or not comment_end_line or not original_suggestion:
                    original_suggestion = suggestion.get(
                        'original_suggestion', None
                    )  # needed for diff code
                    if (
                        not comment_start_line
                        or not comment_end_line
                        or not original_suggestion
                    ):
                        continue

                    # check if the comment is inside a valid hunk
@@ -1037,30 +1254,57 @@ class GithubProvider(GitProvider):
                            patch_range_min = patch_range
                            min_distance = min(min_distance, d)
                    if not is_valid_hunk:
                        if min_distance < 10: # 10 lines - a reasonable distance to consider the comment inside the hunk
                        if (
                            min_distance < 10
                        ):  # 10 lines - a reasonable distance to consider the comment inside the hunk
                            # make the suggestion non-committable, yet multi line
                            suggestion['relevant_lines_start'] = max(suggestion['relevant_lines_start'], patch_range_min['start'])
                            suggestion['relevant_lines_end'] = min(suggestion['relevant_lines_end'], patch_range_min['end'])
                            suggestion['relevant_lines_start'] = max(
                                suggestion['relevant_lines_start'],
                                patch_range_min['start'],
                            )
                            suggestion['relevant_lines_end'] = min(
                                suggestion['relevant_lines_end'],
                                patch_range_min['end'],
                            )
                            body = suggestion['body'].strip()

                            # present new diff code in collapsible
                            existing_code = original_suggestion['existing_code'].rstrip() + "\n"
                            improved_code = original_suggestion['improved_code'].rstrip() + "\n"
                            diff = difflib.unified_diff(existing_code.split('\n'),
                                                        improved_code.split('\n'), n=999)
                            existing_code = (
                                original_suggestion['existing_code'].rstrip() + "\n"
                            )
                            improved_code = (
                                original_suggestion['improved_code'].rstrip() + "\n"
                            )
                            diff = difflib.unified_diff(
                                existing_code.split('\n'),
                                improved_code.split('\n'),
                                n=999,
                            )
                            patch_orig = "\n".join(diff)
                            patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
                            patch = "\n".join(patch_orig.splitlines()[5:]).strip(
                                '\n'
                            )
                            diff_code = f"\n\n<details><summary>New proposed code:</summary>\n\n```diff\n{patch.rstrip()}\n```"
                            # replace ```suggestion ... ``` with diff_code, using regex:
                            body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL)
                            body = re.sub(
                                r'```suggestion.*?```',
                                diff_code,
                                body,
                                flags=re.DOTALL,
                            )
                            body += "\n\n</details>"
                            suggestion['body'] = body
                            get_logger().info(f"Comment was moved to a valid hunk, "
                                              f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}")
                            get_logger().info(
                                f"Comment was moved to a valid hunk, "
                                f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}"
                            )
                        else:
                            get_logger().error(f"Comment is not inside a valid hunk, "
                                               f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}")
                            get_logger().error(
                                f"Comment is not inside a valid hunk, "
                                f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}"
                            )
        except Exception as e:
            get_logger().error(f"Failed to process patch for committable comment, error: {e}")
            get_logger().error(
                f"Failed to process patch for committable comment, error: {e}"
            )
        return code_suggestions_copy

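The block above turns a suggestion into a collapsible diff: `difflib.unified_diff` with `n=999` yields a single full-context hunk, the leading header lines are sliced off, and the remainder is wrapped in a `<details>` block. A self-contained sketch with toy inputs:

```python
import difflib

existing_code = "def add(a, b):\n    return a+b\n"
improved_code = "def add(a: int, b: int) -> int:\n    return a + b\n"

# n=999 requests effectively unlimited context, so everything lands in one hunk.
diff = difflib.unified_diff(existing_code.split('\n'), improved_code.split('\n'), n=999)
patch_orig = "\n".join(diff)
# the join doubles the header newlines, so [5:] plus strip('\n') removes
# exactly the ---/+++/@@ header lines and keeps the hunk body
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
diff_code = f"\n\n<details><summary>New proposed code:</summary>\n\n```diff\n{patch.rstrip()}\n```\n\n</details>"
print(diff_code)
```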
@@ -10,9 +10,11 @@ from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo

from ..algo.file_filter import filter_ignored
from ..algo.language_handler import is_valid_file
from ..algo.utils import (clip_tokens,
                          find_line_number_of_relevant_line_in_file,
                          load_large_diff)
from ..algo.utils import (
    clip_tokens,
    find_line_number_of_relevant_line_in_file,
    load_large_diff,
)
from ..config_loader import get_settings
from ..log import get_logger
from .git_provider import MAX_FILES_ALLOWED_FULL, GitProvider
@@ -20,22 +22,26 @@ from .git_provider import MAX_FILES_ALLOWED_FULL, GitProvider

class DiffNotFoundError(Exception):
    """Raised when the diff for a merge request cannot be found."""

    pass

class GitLabProvider(GitProvider):

    def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):

class GitLabProvider(GitProvider):
    def __init__(
        self,
        merge_request_url: Optional[str] = None,
        incremental: Optional[bool] = False,
    ):
        gitlab_url = get_settings().get("GITLAB.URL", None)
        if not gitlab_url:
            raise ValueError("GitLab URL is not set in the config file")
        self.gitlab_url = gitlab_url
        gitlab_access_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
        if not gitlab_access_token:
            raise ValueError("GitLab personal access token is not set in the config file")
        self.gl = gitlab.Gitlab(
            url=gitlab_url,
            oauth_token=gitlab_access_token
        )
            raise ValueError(
                "GitLab personal access token is not set in the config file"
            )
        self.gl = gitlab.Gitlab(url=gitlab_url, oauth_token=gitlab_access_token)
        self.max_comment_chars = 65000
        self.id_project = None
        self.id_mr = None
@@ -46,12 +52,17 @@ class GitLabProvider(GitProvider):
        self.pr_url = merge_request_url
        self._set_merge_request(merge_request_url)
        self.RE_HUNK_HEADER = re.compile(
            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)"
        )
        self.incremental = incremental

    def is_supported(self, capability: str) -> bool:
        if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments',
                          'publish_file_comments']: # gfm_markdown is supported in gitlab !
        if capability in [
            'get_issue_comments',
            'create_inline_comment',
            'publish_inline_comments',
            'publish_file_comments',
        ]: # gfm_markdown is supported in gitlab !
            return False
        return True

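The `RE_HUNK_HEADER` pattern above parses unified-diff hunk headers into the old/new start lines and sizes, plus the trailing section text. A quick sketch of what it captures; the sample header below is illustrative:

```python
import re

RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

match = RE_HUNK_HEADER.match("@@ -116,22 +135,31 @@ class GitLabProvider(GitProvider):")
if match:
    old_start, old_size, new_start, new_size, section = match.groups()
    print(old_start, old_size, new_start, new_size, section)
    # -> 116 22 135 31 class GitLabProvider(GitProvider):
    # the ,size groups are optional, so "@@ -1 +1 @@" also matches (sizes are None)
```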
@@ -67,12 +78,17 @@ class GitLabProvider(GitProvider):
            self.last_diff = self.mr.diffs.list(get_all=True)[-1]
        except IndexError as e:
            get_logger().error(f"Could not get diff for merge request {self.id_mr}")
            raise DiffNotFoundError(f"Could not get diff for merge request {self.id_mr}") from e

            raise DiffNotFoundError(
                f"Could not get diff for merge request {self.id_mr}"
            ) from e

    def get_pr_file_content(self, file_path: str, branch: str) -> str:
        try:
            return self.gl.projects.get(self.id_project).files.get(file_path, branch).decode()
            return (
                self.gl.projects.get(self.id_project)
                .files.get(file_path, branch)
                .decode()
            )
        except GitlabGetError:
            # In case of file creation the method returns GitlabGetError (404 file not found).
            # In this case we return an empty string for the diff.
@@ -98,10 +114,13 @@ class GitLabProvider(GitProvider):
        try:
            names_original = [diff['new_path'] for diff in diffs_original]
            names_filtered = [diff['new_path'] for diff in diffs]
            get_logger().info(f"Filtered out [ignore] files for merge request {self.id_mr}", extra={
                'original_files': names_original,
                'filtered_files': names_filtered
            })
            get_logger().info(
                f"Filtered out [ignore] files for merge request {self.id_mr}",
                extra={
                    'original_files': names_original,
                    'filtered_files': names_filtered,
                },
            )
        except Exception as e:
            pass

@@ -116,22 +135,31 @@ class GitLabProvider(GitProvider):
            # allow only a limited number of files to be fully loaded. We can manage the rest with diffs only
            counter_valid += 1
            if counter_valid < MAX_FILES_ALLOWED_FULL or not diff['diff']:
                original_file_content_str = self.get_pr_file_content(diff['old_path'], self.mr.diff_refs['base_sha'])
                new_file_content_str = self.get_pr_file_content(diff['new_path'], self.mr.diff_refs['head_sha'])
                original_file_content_str = self.get_pr_file_content(
                    diff['old_path'], self.mr.diff_refs['base_sha']
                )
                new_file_content_str = self.get_pr_file_content(
                    diff['new_path'], self.mr.diff_refs['head_sha']
                )
            else:
                if counter_valid == MAX_FILES_ALLOWED_FULL:
                    get_logger().info(f"Too many files in PR, will avoid loading full content for rest of files")
                    get_logger().info(
                        f"Too many files in PR, will avoid loading full content for rest of files"
                    )
                original_file_content_str = ''
                new_file_content_str = ''

            try:
                if isinstance(original_file_content_str, bytes):
                    original_file_content_str = bytes.decode(original_file_content_str, 'utf-8')
                    original_file_content_str = bytes.decode(
                        original_file_content_str, 'utf-8'
                    )
                if isinstance(new_file_content_str, bytes):
                    new_file_content_str = bytes.decode(new_file_content_str, 'utf-8')
            except UnicodeDecodeError:
                get_logger().warning(
                    f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}")
                    f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}"
                )

            edit_type = EDIT_TYPE.MODIFIED
            if diff['new_file']:
@@ -144,30 +172,43 @@ class GitLabProvider(GitProvider):
            filename = diff['new_path']
            patch = diff['diff']
            if not patch:
                patch = load_large_diff(filename, new_file_content_str, original_file_content_str)

                patch = load_large_diff(
                    filename, new_file_content_str, original_file_content_str
                )

            # count number of lines added and removed
            patch_lines = patch.splitlines(keepends=True)
            num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
            num_minus_lines = len([line for line in patch_lines if line.startswith('-')])
            num_minus_lines = len(
                [line for line in patch_lines if line.startswith('-')]
            )
            diff_files.append(
                FilePatchInfo(original_file_content_str, new_file_content_str,
                              patch=patch,
                              filename=filename,
                              edit_type=edit_type,
                              old_filename=None if diff['old_path'] == diff['new_path'] else diff['old_path'],
                              num_plus_lines=num_plus_lines,
                              num_minus_lines=num_minus_lines, ))
                FilePatchInfo(
                    original_file_content_str,
                    new_file_content_str,
                    patch=patch,
                    filename=filename,
                    edit_type=edit_type,
                    old_filename=None
                    if diff['old_path'] == diff['new_path']
                    else diff['old_path'],
                    num_plus_lines=num_plus_lines,
                    num_minus_lines=num_minus_lines,
                )
            )
        if invalid_files_names:
            get_logger().info(f"Filtered out files with invalid extensions: {invalid_files_names}")
            get_logger().info(
                f"Filtered out files with invalid extensions: {invalid_files_names}"
            )

        self.diff_files = diff_files
        return diff_files

    def get_files(self) -> list:
        if not self.git_files:
            self.git_files = [change['new_path'] for change in self.mr.changes()['changes']]
            self.git_files = [
                change['new_path'] for change in self.mr.changes()['changes']
            ]
        return self.git_files

    def publish_description(self, pr_title: str, pr_body: str):
@@ -176,7 +217,9 @@ class GitLabProvider(GitProvider):
            self.mr.description = pr_body
            self.mr.save()
        except Exception as e:
            get_logger().exception(f"Could not update merge request {self.id_mr} description: {e}")
            get_logger().exception(
                f"Could not update merge request {self.id_mr} description: {e}"
            )

    def get_latest_commit_url(self):
        return self.mr.commits().next().web_url
@@ -184,16 +227,23 @@ class GitLabProvider(GitProvider):
    def get_comment_url(self, comment):
        return f"{self.mr.web_url}#note_{comment.id}"

    def publish_persistent_comment(self, pr_comment: str,
                                   initial_header: str,
                                   update_header: bool = True,
                                   name='review',
                                   final_update_message=True):
        self.publish_persistent_comment_full(pr_comment, initial_header, update_header, name, final_update_message)
    def publish_persistent_comment(
        self,
        pr_comment: str,
        initial_header: str,
        update_header: bool = True,
        name='review',
        final_update_message=True,
    ):
        self.publish_persistent_comment_full(
            pr_comment, initial_header, update_header, name, final_update_message
        )

    def publish_comment(self, mr_comment: str, is_temporary: bool = False):
        if is_temporary and not get_settings().config.publish_output_progress:
            get_logger().debug(f"Skipping publish_comment for temporary comment: {mr_comment}")
            get_logger().debug(
                f"Skipping publish_comment for temporary comment: {mr_comment}"
            )
            return None
        mr_comment = self.limit_output_characters(mr_comment, self.max_comment_chars)
        comment = self.mr.notes.create({'body': mr_comment})
@@ -203,7 +253,7 @@ class GitLabProvider(GitProvider):

    def edit_comment(self, comment, body: str):
        body = self.limit_output_characters(body, self.max_comment_chars)
        self.mr.notes.update(comment.id,{'body': body} )
        self.mr.notes.update(comment.id, {'body': body})

    def edit_comment_from_comment_id(self, comment_id: int, body: str):
        body = self.limit_output_characters(body, self.max_comment_chars)
@@ -216,39 +266,87 @@ class GitLabProvider(GitProvider):
        discussion = self.mr.discussions.get(comment_id)
        discussion.notes.create({'body': body})

    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
    def publish_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        original_suggestion=None,
    ):
        body = self.limit_output_characters(body, self.max_comment_chars)
        edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file,
                                                                                         relevant_line_in_file)
        self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
                                 target_file, target_line_no, original_suggestion)
        (
            edit_type,
            found,
            source_line_no,
            target_file,
            target_line_no,
        ) = self.search_line(relevant_file, relevant_line_in_file)
        self.send_inline_comment(
            body,
            edit_type,
            found,
            relevant_file,
            relevant_line_in_file,
            source_line_no,
            target_file,
            target_line_no,
            original_suggestion,
        )

    def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, absolute_position: int = None):
        raise NotImplementedError("Gitlab provider does not support creating inline comments yet")
    def create_inline_comment(
        self,
        body: str,
        relevant_file: str,
        relevant_line_in_file: str,
        absolute_position: int = None,
    ):
        raise NotImplementedError(
            "Gitlab provider does not support creating inline comments yet"
        )

    def create_inline_comments(self, comments: list[dict]):
        raise NotImplementedError("Gitlab provider does not support publishing inline comments yet")
        raise NotImplementedError(
            "Gitlab provider does not support publishing inline comments yet"
        )

    def get_comment_body_from_comment_id(self, comment_id: int):
        comment = self.mr.notes.get(comment_id).body
        return comment

    def send_inline_comment(self, body: str, edit_type: str, found: bool, relevant_file: str,
                            relevant_line_in_file: str,
                            source_line_no: int, target_file: str, target_line_no: int,
                            original_suggestion=None) -> None:
    def send_inline_comment(
        self,
        body: str,
        edit_type: str,
        found: bool,
        relevant_file: str,
        relevant_line_in_file: str,
        source_line_no: int,
        target_file: str,
        target_line_no: int,
        original_suggestion=None,
    ) -> None:
        if not found:
            get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
            get_logger().info(
                f"Could not find position for {relevant_file} {relevant_line_in_file}"
            )
        else:
            # in order to have exact sha's we have to find correct diff for this change
            diff = self.get_relevant_diff(relevant_file, relevant_line_in_file)
            if diff is None:
                get_logger().error(f"Could not get diff for merge request {self.id_mr}")
                raise DiffNotFoundError(f"Could not get diff for merge request {self.id_mr}")
            pos_obj = {'position_type': 'text',
                       'new_path': target_file.filename,
                       'old_path': target_file.old_filename if target_file.old_filename else target_file.filename,
                       'base_sha': diff.base_commit_sha, 'start_sha': diff.start_commit_sha, 'head_sha': diff.head_commit_sha}
                raise DiffNotFoundError(
                    f"Could not get diff for merge request {self.id_mr}"
                )
            pos_obj = {
                'position_type': 'text',
                'new_path': target_file.filename,
                'old_path': target_file.old_filename
                if target_file.old_filename
                else target_file.filename,
                'base_sha': diff.base_commit_sha,
                'start_sha': diff.start_commit_sha,
                'head_sha': diff.head_commit_sha,
            }
            if edit_type == 'deletion':
                pos_obj['old_line'] = source_line_no - 1
            elif edit_type == 'addition':
@@ -256,15 +354,21 @@ class GitLabProvider(GitProvider):
            else:
                pos_obj['new_line'] = target_line_no - 1
                pos_obj['old_line'] = source_line_no - 1
            get_logger().debug(f"Creating comment in MR {self.id_mr} with body {body} and position {pos_obj}")
            get_logger().debug(
                f"Creating comment in MR {self.id_mr} with body {body} and position {pos_obj}"
            )
            try:
                self.mr.discussions.create({'body': body, 'position': pos_obj})
            except Exception as e:
                try:
                    # fallback - create a general note on the file in the MR
                    if 'suggestion_orig_location' in original_suggestion:
                        line_start = original_suggestion['suggestion_orig_location']['start_line']
                        line_end = original_suggestion['suggestion_orig_location']['end_line']
|
||||
line_start = original_suggestion['suggestion_orig_location'][
|
||||
'start_line'
|
||||
]
|
||||
line_end = original_suggestion['suggestion_orig_location'][
|
||||
'end_line'
|
||||
]
|
||||
old_code_snippet = original_suggestion['prev_code_snippet']
|
||||
new_code_snippet = original_suggestion['new_code_snippet']
|
||||
content = original_suggestion['suggestion_summary']
|
||||
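The `pos_obj` dict built in `send_inline_comment` follows the position schema of the GitLab merge-request discussions API. A hedged standalone sketch of the same call with python-gitlab (instance URL, token, project path, MR iid, file path, and line number are all placeholder assumptions):

```python
import gitlab

# Placeholder connection details - adjust for a real instance.
gl = gitlab.Gitlab("https://gitlab.example.com", private_token="TOKEN")
mr = gl.projects.get("group/project").mergerequests.get(42)

# The latest diff version supplies the three SHAs the position object needs.
diff = mr.diffs.list()[0]
mr.discussions.create({
    'body': "Example inline note",
    'position': {
        'position_type': 'text',
        'new_path': "src/app.py",
        'old_path': "src/app.py",
        'base_sha': diff.base_commit_sha,
        'start_sha': diff.start_commit_sha,
        'head_sha': diff.head_commit_sha,
        'new_line': 10,
    },
})
```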
@@ -287,36 +391,49 @@ class GitLabProvider(GitProvider):
                         else:
                             language = ''
                         link = self.get_line_link(relevant_file, line_start, line_end)
-                        body_fallback =f"**Suggestion:** {content} [{label}, importance: {score}]\n\n"
-                        body_fallback +=f"\n\n<details><summary>[{target_file.filename} [{line_start}-{line_end}]]({link}):</summary>\n\n"
+                        body_fallback = (
+                            f"**Suggestion:** {content} [{label}, importance: {score}]\n\n"
+                        )
+                        body_fallback += f"\n\n<details><summary>[{target_file.filename} [{line_start}-{line_end}]]({link}):</summary>\n\n"
                         body_fallback += f"\n\n___\n\n`(Cannot implement directly - GitLab API allows committable suggestions strictly on MR diff lines)`"
-                        body_fallback+="</details>\n\n"
-                        diff_patch = difflib.unified_diff(old_code_snippet.split('\n'),
-                                                          new_code_snippet.split('\n'), n=999)
+                        body_fallback += "</details>\n\n"
+                        diff_patch = difflib.unified_diff(
+                            old_code_snippet.split('\n'),
+                            new_code_snippet.split('\n'),
+                            n=999,
+                        )
                         patch_orig = "\n".join(diff_patch)
                         patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
                         diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
                         body_fallback += diff_code

                         # Create a general note on the file in the MR
-                        self.mr.notes.create({
-                            'body': body_fallback,
-                            'position': {
-                                'base_sha': diff.base_commit_sha,
-                                'start_sha': diff.start_commit_sha,
-                                'head_sha': diff.head_commit_sha,
-                                'position_type': 'text',
-                                'file_path': f'{target_file.filename}',
-                            }
-                        })
-                        get_logger().debug(f"Created fallback comment in MR {self.id_mr} with position {pos_obj}")
+                        self.mr.notes.create(
+                            {
+                                'body': body_fallback,
+                                'position': {
+                                    'base_sha': diff.base_commit_sha,
+                                    'start_sha': diff.start_commit_sha,
+                                    'head_sha': diff.head_commit_sha,
+                                    'position_type': 'text',
+                                    'file_path': f'{target_file.filename}',
+                                },
+                            }
+                        )
+                        get_logger().debug(
+                            f"Created fallback comment in MR {self.id_mr} with position {pos_obj}"
+                        )

                     # get_logger().debug(
                     #     f"Failed to create comment in MR {self.id_mr} with position {pos_obj} (probably not a '+' line)")
                 except Exception as e:
-                    get_logger().exception(f"Failed to create comment in MR {self.id_mr}")
+                    get_logger().exception(
+                        f"Failed to create comment in MR {self.id_mr}"
+                    )

-    def get_relevant_diff(self, relevant_file: str, relevant_line_in_file: str) -> Optional[dict]:
+    def get_relevant_diff(
+        self, relevant_file: str, relevant_line_in_file: str
+    ) -> Optional[dict]:
         changes = self.mr.changes()  # Retrieve the changes for the merge request once
         if not changes:
             get_logger().error('No changes found for the merge request.')
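The fallback body embeds a `difflib` patch. A minimal standalone sketch of that rendering (the snippet strings are made up; `splitlines()[5:]` mirrors the header-stripping trick above):

```python
import difflib

old_code_snippet = "def add(a, b):\n    return a+b"
new_code_snippet = "def add(a: int, b: int) -> int:\n    return a + b"

# n=999 keeps the whole snippet in a single hunk.
diff_patch = difflib.unified_diff(
    old_code_snippet.split('\n'),
    new_code_snippet.split('\n'),
    n=999,
)
patch_orig = "\n".join(diff_patch)
# Drop the '---'/'+++'/'@@' header lines before embedding.
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
print(f"```diff\n{patch.rstrip()}\n```")
```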
@@ -327,10 +444,14 @@ class GitLabProvider(GitProvider):
             return None
         for diff in all_diffs:
             for change in changes['changes']:
-                if change['new_path'] == relevant_file and relevant_line_in_file in change['diff']:
+                if (
+                    change['new_path'] == relevant_file
+                    and relevant_line_in_file in change['diff']
+                ):
                     return diff
         get_logger().debug(
-            f'No relevant diff found for {relevant_file} {relevant_line_in_file}. Falling back to last diff.')
+            f'No relevant diff found for {relevant_file} {relevant_line_in_file}. Falling back to last diff.'
+        )
         return self.last_diff  # fallback to last_diff if no relevant diff is found

     def publish_code_suggestions(self, code_suggestions: list) -> bool:

@@ -352,7 +473,7 @@ class GitLabProvider(GitProvider):
                 if file.filename == relevant_file:
                     target_file = file
                     break
-            range = relevant_lines_end - relevant_lines_start # no need to add 1
+            range = relevant_lines_end - relevant_lines_start  # no need to add 1
             body = body.replace('```suggestion', f'```suggestion:-0+{range}')
             lines = target_file.head_file.splitlines()
             relevant_line_in_file = lines[relevant_lines_start - 1]

@@ -365,10 +486,21 @@ class GitLabProvider(GitProvider):
                 found = True
                 edit_type = 'addition'

-            self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
-                                     target_file, target_line_no, original_suggestion)
+            self.send_inline_comment(
+                body,
+                edit_type,
+                found,
+                relevant_file,
+                relevant_line_in_file,
+                source_line_no,
+                target_file,
+                target_line_no,
+                original_suggestion,
+            )
         except Exception as e:
-            get_logger().exception(f"Could not publish code suggestion:\nsuggestion: {suggestion}\nerror: {e}")
+            get_logger().exception(
+                f"Could not publish code suggestion:\nsuggestion: {suggestion}\nerror: {e}"
+            )

         # note that we publish suggestions one-by-one. so, if one fails, the rest will still be published
         return True

@@ -382,8 +514,13 @@ class GitLabProvider(GitProvider):
         edit_type = self.get_edit_type(relevant_line_in_file)
         for file in self.get_diff_files():
             if file.filename == relevant_file:
-                edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(file,
-                                                                                                  relevant_line_in_file)
+                (
+                    edit_type,
+                    found,
+                    source_line_no,
+                    target_file,
+                    target_line_no,
+                ) = self.find_in_file(file, relevant_line_in_file)
         return edit_type, found, source_line_no, target_file, target_line_no

     def find_in_file(self, file, relevant_line_in_file):

@@ -414,7 +551,10 @@ class GitLabProvider(GitProvider):
                 found = True
                 edit_type = self.get_edit_type(line)
                 break
-            elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:].lstrip() in line:
+            elif (
+                relevant_line_in_file[0] == '+'
+                and relevant_line_in_file[1:].lstrip() in line
+            ):
                 # The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
                 # it's a context line
                 found = True

@@ -470,7 +610,11 @@ class GitLabProvider(GitProvider):

     def get_repo_settings(self):
         try:
-            contents = self.gl.projects.get(self.id_project).files.get(file_path='.pr_agent.toml', ref=self.mr.target_branch).decode()
+            contents = (
+                self.gl.projects.get(self.id_project)
+                .files.get(file_path='.pr_agent.toml', ref=self.mr.target_branch)
+                .decode()
+            )
             return contents
         except Exception:
             return ""
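For reference, the fluent chain in `get_repo_settings()` resolves to plain python-gitlab calls. A hedged standalone sketch (instance URL, token, project path, and ref are placeholders):

```python
import gitlab

gl = gitlab.Gitlab("https://gitlab.example.com", private_token="TOKEN")
project = gl.projects.get("group/project")
# files.get() returns a ProjectFile; decode() yields the raw file bytes.
settings_bytes = project.files.get(file_path='.pr_agent.toml', ref='main').decode()
print(settings_bytes.decode('utf-8'))
```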
@@ -478,7 +622,9 @@ class GitLabProvider(GitProvider):
     def get_workspace_name(self):
         return self.id_project.split('/')[0]

-    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
+    def add_eyes_reaction(
+        self, issue_comment_id: int, disable_eyes: bool = False
+    ) -> Optional[int]:
         return True

     def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:

@@ -489,7 +635,9 @@ class GitLabProvider(GitProvider):

         path_parts = parsed_url.path.strip('/').split('/')
         if 'merge_requests' not in path_parts:
-            raise ValueError("The provided URL does not appear to be a GitLab merge request URL")
+            raise ValueError(
+                "The provided URL does not appear to be a GitLab merge request URL"
+            )

         mr_index = path_parts.index('merge_requests')
         # Ensure there is an ID after 'merge_requests'

@@ -541,8 +689,15 @@ class GitLabProvider(GitProvider):
         """
         max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None)
         try:
-            commit_messages_list = [commit['message'] for commit in self.mr.commits()._list]
-            commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages_list)])
+            commit_messages_list = [
+                commit['message'] for commit in self.mr.commits()._list
+            ]
+            commit_messages_str = "\n".join(
+                [
+                    f"{i + 1}. {message}"
+                    for i, message in enumerate(commit_messages_list)
+                ]
+            )
         except Exception:
             commit_messages_str = ""
         if max_tokens:

@@ -556,7 +711,12 @@ class GitLabProvider(GitProvider):
         except:
             return ""

-    def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
+    def get_line_link(
+        self,
+        relevant_file: str,
+        relevant_line_start: int,
+        relevant_line_end: int = None,
+    ) -> str:
         if relevant_line_start == -1:
             link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads"
         elif relevant_line_end:

@@ -565,7 +725,6 @@ class GitLabProvider(GitProvider):
             link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads#L{relevant_line_start}"
         return link

-
     def generate_link_to_relevant_line_number(self, suggestion) -> str:
         try:
             relevant_file = suggestion['relevant_file'].strip('`').strip("'").rstrip()

@@ -573,8 +732,9 @@ class GitLabProvider(GitProvider):
         if not relevant_line_str:
             return ""

-        position, absolute_position = find_line_number_of_relevant_line_in_file \
-            (self.diff_files, relevant_file, relevant_line_str)
+        position, absolute_position = find_line_number_of_relevant_line_in_file(
+            self.diff_files, relevant_file, relevant_line_str
+        )

         if absolute_position != -1:
             # link to right file only
@@ -39,10 +39,16 @@ class LocalGitProvider(GitProvider):
         self._prepare_repo()
         self.diff_files = None
         self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())
-        self.description_path = get_settings().get('local.description_path') \
-            if get_settings().get('local.description_path') is not None else self.repo_path / 'description.md'
-        self.review_path = get_settings().get('local.review_path') \
-            if get_settings().get('local.review_path') is not None else self.repo_path / 'review.md'
+        self.description_path = (
+            get_settings().get('local.description_path')
+            if get_settings().get('local.description_path') is not None
+            else self.repo_path / 'description.md'
+        )
+        self.review_path = (
+            get_settings().get('local.review_path')
+            if get_settings().get('local.review_path') is not None
+            else self.repo_path / 'review.md'
+        )
         # inline code comments are not supported for local git repositories
         get_settings().pr_reviewer.inline_code_comments = False

@@ -52,30 +58,43 @@ class LocalGitProvider(GitProvider):
         """
         get_logger().debug('Preparing repository for PR-mimic generation...')
         if self.repo.is_dirty():
-            raise ValueError('The repository is not in a clean state. Please commit or stash pending changes.')
+            raise ValueError(
+                'The repository is not in a clean state. Please commit or stash pending changes.'
+            )
         if self.target_branch_name not in self.repo.heads:
             raise KeyError(f'Branch: {self.target_branch_name} does not exist')

     def is_supported(self, capability: str) -> bool:
-        if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments', 'get_labels',
-                          'gfm_markdown']:
+        if capability in [
+            'get_issue_comments',
+            'create_inline_comment',
+            'publish_inline_comments',
+            'get_labels',
+            'gfm_markdown',
+        ]:
             return False
         return True

     def get_diff_files(self) -> list[FilePatchInfo]:
         diffs = self.repo.head.commit.diff(
-            self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]),
+            self.repo.merge_base(
+                self.repo.head, self.repo.branches[self.target_branch_name]
+            ),
             create_patch=True,
-            R=True
+            R=True,
         )
         diff_files = []
         for diff_item in diffs:
             if diff_item.a_blob is not None:
-                original_file_content_str = diff_item.a_blob.data_stream.read().decode('utf-8')
+                original_file_content_str = diff_item.a_blob.data_stream.read().decode(
+                    'utf-8'
+                )
             else:
                 original_file_content_str = ""  # empty file
             if diff_item.b_blob is not None:
-                new_file_content_str = diff_item.b_blob.data_stream.read().decode('utf-8')
+                new_file_content_str = diff_item.b_blob.data_stream.read().decode(
+                    'utf-8'
+                )
             else:
                 new_file_content_str = ""  # empty file
             edit_type = EDIT_TYPE.MODIFIED

@@ -86,13 +105,16 @@ class LocalGitProvider(GitProvider):
             elif diff_item.renamed_file:
                 edit_type = EDIT_TYPE.RENAMED
             diff_files.append(
-                FilePatchInfo(original_file_content_str,
-                              new_file_content_str,
-                              diff_item.diff.decode('utf-8'),
-                              diff_item.b_path,
-                              edit_type=edit_type,
-                              old_filename=None if diff_item.a_path == diff_item.b_path else diff_item.a_path
-                              )
+                FilePatchInfo(
+                    original_file_content_str,
+                    new_file_content_str,
+                    diff_item.diff.decode('utf-8'),
+                    diff_item.b_path,
+                    edit_type=edit_type,
+                    old_filename=None
+                    if diff_item.a_path == diff_item.b_path
+                    else diff_item.a_path,
+                )
             )
         self.diff_files = diff_files
         return diff_files
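`get_diff_files` uses GitPython to diff HEAD against the merge base of HEAD and the target branch, with `R=True` reversing the diff direction. A hedged standalone sketch (the repo path and branch name are assumptions):

```python
from git import Repo

repo = Repo(".")  # assumes the current directory is a git repository
target = repo.branches["main"]  # placeholder target branch
# merge_base() returns a list of candidate commits; the first is enough here.
merge_base = repo.merge_base(repo.head, target)[0]
diffs = repo.head.commit.diff(merge_base, create_patch=True, R=True)
for diff_item in diffs:
    print(diff_item.b_path, diff_item.change_type)
```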
@@ -102,8 +124,10 @@ class LocalGitProvider(GitProvider):
         Returns a list of files with changes in the diff.
         """
         diff_index = self.repo.head.commit.diff(
-            self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]),
-            R=True
+            self.repo.merge_base(
+                self.repo.head, self.repo.branches[self.target_branch_name]
+            ),
+            R=True,
         )
         # Get the list of changed files
         diff_files = [item.a_path for item in diff_index]

@@ -119,18 +143,37 @@ class LocalGitProvider(GitProvider):
             # Write the string to the file
             file.write(pr_comment)

-    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
-        raise NotImplementedError('Publishing inline comments is not implemented for the local git provider')
+    def publish_inline_comment(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_line_in_file: str,
+        original_suggestion=None,
+    ):
+        raise NotImplementedError(
+            'Publishing inline comments is not implemented for the local git provider'
+        )

     def publish_inline_comments(self, comments: list[dict]):
-        raise NotImplementedError('Publishing inline comments is not implemented for the local git provider')
+        raise NotImplementedError(
+            'Publishing inline comments is not implemented for the local git provider'
+        )

-    def publish_code_suggestion(self, body: str, relevant_file: str,
-                                relevant_lines_start: int, relevant_lines_end: int):
-        raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')
+    def publish_code_suggestion(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_lines_start: int,
+        relevant_lines_end: int,
+    ):
+        raise NotImplementedError(
+            'Publishing code suggestions is not implemented for the local git provider'
+        )

     def publish_code_suggestions(self, code_suggestions: list) -> bool:
-        raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')
+        raise NotImplementedError(
+            'Publishing code suggestions is not implemented for the local git provider'
+        )

     def publish_labels(self, labels):
         pass  # Not applicable to the local git provider, but required by the interface

@@ -158,19 +201,31 @@ class LocalGitProvider(GitProvider):
         Calculate percentage of languages in repository. Used for hunk prioritisation.
         """
         # Get all files in repository
-        filepaths = [Path(item.path) for item in self.repo.tree().traverse() if item.type == 'blob']
+        filepaths = [
+            Path(item.path)
+            for item in self.repo.tree().traverse()
+            if item.type == 'blob'
+        ]
         # Identify language by file extension and count
-        lang_count = Counter(ext.lstrip('.') for filepath in filepaths for ext in [filepath.suffix.lower()])
+        lang_count = Counter(
+            ext.lstrip('.')
+            for filepath in filepaths
+            for ext in [filepath.suffix.lower()]
+        )
         # Convert counts to percentages
         total_files = len(filepaths)
-        lang_percentage = {lang: count / total_files * 100 for lang, count in lang_count.items()}
+        lang_percentage = {
+            lang: count / total_files * 100 for lang, count in lang_count.items()
+        }
         return lang_percentage

     def get_pr_branch(self):
         return self.repo.head

     def get_user_id(self):
-        return -1  # Not used anywhere for the local provider, but required by the interface
+        return (
+            -1
+        )  # Not used anywhere for the local provider, but required by the interface

     def get_pr_description_full(self):
         commits_diff = list(self.repo.iter_commits(self.target_branch_name + '..HEAD'))
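The language-percentage logic is pure stdlib. A small self-contained sketch of the same computation with made-up file paths:

```python
from collections import Counter
from pathlib import Path

filepaths = [Path("a.py"), Path("b.py"), Path("docs/readme.md")]  # stand-ins
lang_count = Counter(p.suffix.lower().lstrip('.') for p in filepaths)
total_files = len(filepaths)
lang_percentage = {lang: count / total_files * 100 for lang, count in lang_count.items()}
print(lang_percentage)  # {'py': 66.66..., 'md': 33.33...}
```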
@@ -186,7 +241,11 @@ class LocalGitProvider(GitProvider):
         return self.head_branch_name

     def get_issue_comments(self):
-        raise NotImplementedError('Getting issue comments is not implemented for the local git provider')
+        raise NotImplementedError(
+            'Getting issue comments is not implemented for the local git provider'
+        )

     def get_pr_labels(self, update=False):
-        raise NotImplementedError('Getting labels is not implemented for the local git provider')
+        raise NotImplementedError(
+            'Getting labels is not implemented for the local git provider'
+        )

@@ -6,7 +6,7 @@ from dynaconf import Dynaconf
 from starlette_context import context

 from utils.pr_agent.config_loader import get_settings
-from utils.pr_agent.git_providers import (get_git_provider_with_context)
+from utils.pr_agent.git_providers import get_git_provider_with_context
 from utils.pr_agent.log import get_logger


@@ -20,7 +20,9 @@ def apply_repo_settings(pr_url):
     except Exception:
         repo_settings = None
         pass
-    if repo_settings is None:  # None is different from "", which is a valid value
+    if (
+        repo_settings is None
+    ):  # None is different from "", which is a valid value
         repo_settings = git_provider.get_repo_settings()
         try:
             context["repo_settings"] = repo_settings

@@ -36,15 +38,25 @@ def apply_repo_settings(pr_url):
             os.write(fd, repo_settings)
             new_settings = Dynaconf(settings_files=[repo_settings_file])
             for section, contents in new_settings.as_dict().items():
-                section_dict = copy.deepcopy(get_settings().as_dict().get(section, {}))
+                section_dict = copy.deepcopy(
+                    get_settings().as_dict().get(section, {})
+                )
                 for key, value in contents.items():
                     section_dict[key] = value
                 get_settings().unset(section)
                 get_settings().set(section, section_dict, merge=False)
-            get_logger().info(f"Applying repo settings:\n{new_settings.as_dict()}")
+            get_logger().info(
+                f"Applying repo settings:\n{new_settings.as_dict()}"
+            )
         except Exception as e:
-            get_logger().warning(f"Failed to apply repo {category} settings, error: {str(e)}")
-            error_local = {'error': str(e), 'settings': repo_settings, 'category': category}
+            get_logger().warning(
+                f"Failed to apply repo {category} settings, error: {str(e)}"
+            )
+            error_local = {
+                'error': str(e),
+                'settings': repo_settings,
+                'category': category,
+            }

         if error_local:
             handle_configurations_errors([error_local], git_provider)
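The section-merge loop above deep-copies each existing section, overlays the repo-level keys, and re-sets the section wholesale (`merge=False`). A hedged standalone sketch with plain Dynaconf objects (the file names are illustrative):

```python
import copy

from dynaconf import Dynaconf

base = Dynaconf(settings_files=["settings.toml"])        # global defaults
override = Dynaconf(settings_files=[".pr_agent.toml"])   # repo-level overrides

for section, contents in override.as_dict().items():
    # Overlay repo keys on a copy of the existing section, then replace it.
    section_dict = copy.deepcopy(base.as_dict().get(section, {}))
    section_dict.update(contents)
    base.unset(section)
    base.set(section, section_dict, merge=False)
```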
@@ -55,7 +67,10 @@ def apply_repo_settings(pr_url):
         try:
             os.remove(repo_settings_file)
         except Exception as e:
-            get_logger().error(f"Failed to remove temporary settings file {repo_settings_file}", e)
+            get_logger().error(
+                f"Failed to remove temporary settings file {repo_settings_file}",
+                e,
+            )

     # enable switching models with a short definition
     if get_settings().config.model.lower() == 'claude-3-5-sonnet':

@@ -79,13 +94,18 @@ def handle_configurations_errors(config_errors, git_provider):
             body += f"\n\n<details><summary>配置内容:</summary>\n\n```toml\n{configuration_file_content}\n```\n\n</details>"
         else:
             body += f"\n\n**配置内容:**\n\n```toml\n{configuration_file_content}\n```\n\n"
-        get_logger().warning(f"Sending a 'configuration error' comment to the PR", artifact={'body': body})
+        get_logger().warning(
+            f"Sending a 'configuration error' comment to the PR",
+            artifact={'body': body},
+        )
         # git_provider.publish_comment(body)
         if hasattr(git_provider, 'publish_persistent_comment'):
-            git_provider.publish_persistent_comment(body,
-                                                    initial_header=header,
-                                                    update_header=False,
-                                                    final_update_message=False)
+            git_provider.publish_persistent_comment(
+                body,
+                initial_header=header,
+                update_header=False,
+                final_update_message=False,
+            )
         else:
             git_provider.publish_comment(body)
     except Exception as e:

@@ -1,10 +1,9 @@
 from utils.pr_agent.config_loader import get_settings
-from utils.pr_agent.identity_providers.default_identity_provider import \
-    DefaultIdentityProvider
+from utils.pr_agent.identity_providers.default_identity_provider import (
+    DefaultIdentityProvider,
+)

-_IDENTITY_PROVIDERS = {
-    'default': DefaultIdentityProvider
-}
+_IDENTITY_PROVIDERS = {'default': DefaultIdentityProvider}


 def get_identity_provider():

@@ -1,5 +1,7 @@
-from utils.pr_agent.identity_providers.identity_provider import (Eligibility,
-                                                                 IdentityProvider)
+from utils.pr_agent.identity_providers.identity_provider import (
+    Eligibility,
+    IdentityProvider,
+)


 class DefaultIdentityProvider(IdentityProvider):

@@ -30,7 +30,9 @@ def setup_logger(level: str = "INFO", fmt: LoggingFormat = LoggingFormat.CONSOLE
     if type(level) is not int:
         level = logging.INFO

-    if fmt == LoggingFormat.JSON and os.getenv("LOG_SANE", "0").lower() == "0":  # better debugging github_app
+    if (
+        fmt == LoggingFormat.JSON and os.getenv("LOG_SANE", "0").lower() == "0"
+    ):  # better debugging github_app
         logger.remove(None)
         logger.add(
             sys.stdout,

@@ -40,7 +42,7 @@ def setup_logger(level: str = "INFO", fmt: LoggingFormat = LoggingFormat.CONSOLE
             colorize=False,
             serialize=True,
         )
-    elif fmt == LoggingFormat.CONSOLE: # does not print the 'extra' fields
+    elif fmt == LoggingFormat.CONSOLE:  # does not print the 'extra' fields
         logger.remove(None)
         logger.add(sys.stdout, level=level, colorize=True, filter=inv_analytics_filter)

@@ -8,10 +8,14 @@ def get_secret_provider():
     provider_id = get_settings().config.secret_provider
     if provider_id == 'google_cloud_storage':
         try:
-            from utils.pr_agent.secret_providers.google_cloud_storage_secret_provider import \
-                GoogleCloudStorageSecretProvider
+            from utils.pr_agent.secret_providers.google_cloud_storage_secret_provider import (
+                GoogleCloudStorageSecretProvider,
+            )

             return GoogleCloudStorageSecretProvider()
         except Exception as e:
-            raise ValueError(f"Failed to initialize google_cloud_storage secret provider {provider_id}") from e
+            raise ValueError(
+                f"Failed to initialize google_cloud_storage secret provider {provider_id}"
+            ) from e
     else:
         raise ValueError("Unknown SECRET_PROVIDER")
@@ -9,12 +9,15 @@ from utils.pr_agent.secret_providers.secret_provider import SecretProvider
 class GoogleCloudStorageSecretProvider(SecretProvider):
     def __init__(self):
         try:
-            self.client = storage.Client.from_service_account_info(ujson.loads(get_settings().google_cloud_storage.
-                                                                               service_account))
+            self.client = storage.Client.from_service_account_info(
+                ujson.loads(get_settings().google_cloud_storage.service_account)
+            )
             self.bucket_name = get_settings().google_cloud_storage.bucket_name
             self.bucket = self.client.bucket(self.bucket_name)
         except Exception as e:
-            get_logger().error(f"Failed to initialize Google Cloud Storage Secret Provider: {e}")
+            get_logger().error(
+                f"Failed to initialize Google Cloud Storage Secret Provider: {e}"
+            )
             raise e

     def get_secret(self, secret_name: str) -> str:

@@ -22,7 +25,9 @@ class GoogleCloudStorageSecretProvider(SecretProvider):
             blob = self.bucket.blob(secret_name)
             return blob.download_as_string()
         except Exception as e:
-            get_logger().warning(f"Failed to get secret {secret_name} from Google Cloud Storage: {e}")
+            get_logger().warning(
+                f"Failed to get secret {secret_name} from Google Cloud Storage: {e}"
+            )
             return ""

     def store_secret(self, secret_name: str, secret_value: str):

@@ -30,5 +35,7 @@ class GoogleCloudStorageSecretProvider(SecretProvider):
             blob = self.bucket.blob(secret_name)
             blob.upload_from_string(secret_value)
         except Exception as e:
-            get_logger().error(f"Failed to store secret {secret_name} in Google Cloud Storage: {e}")
+            get_logger().error(
+                f"Failed to store secret {secret_name} in Google Cloud Storage: {e}"
+            )
             raise e

@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod


 class SecretProvider(ABC):
-
     @abstractmethod
     def get_secret(self, secret_name: str) -> str:
         pass
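The provider is a thin wrapper over `google-cloud-storage` blobs. A hedged round-trip sketch (the service-account file and bucket name are placeholders):

```python
import json

from google.cloud import storage

# Placeholder credentials file and bucket; adjust for a real project.
sa_info = json.load(open("service_account.json"))
client = storage.Client.from_service_account_info(sa_info)
bucket = client.bucket("my-secrets-bucket")

bucket.blob("client-key").upload_from_string('{"shared_secret": "s3cr3t"}')
print(bucket.blob("client-key").download_as_string())  # b'{"shared_secret": "s3cr3t"}'
```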
@@ -33,6 +33,7 @@ azure_devops_server = get_settings().get("azure_devops_server")
 WEBHOOK_USERNAME = azure_devops_server.get("webhook_username")
 WEBHOOK_PASSWORD = azure_devops_server.get("webhook_password")

+
 def handle_request(
     background_tasks: BackgroundTasks, url: str, body: str, log_context: dict
 ):

@@ -52,20 +53,27 @@ def handle_request(
 # currently only basic auth is supported with azure webhooks
 # for this reason, https must be enabled to ensure the credentials are not sent in clear text
 def authorize(credentials: HTTPBasicCredentials = Depends(security)):
-    is_user_ok = secrets.compare_digest(credentials.username, WEBHOOK_USERNAME)
-    is_pass_ok = secrets.compare_digest(credentials.password, WEBHOOK_PASSWORD)
-    if not (is_user_ok and is_pass_ok):
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail='Incorrect username or password.',
-            headers={'WWW-Authenticate': 'Basic'},
-        )
+    is_user_ok = secrets.compare_digest(credentials.username, WEBHOOK_USERNAME)
+    is_pass_ok = secrets.compare_digest(credentials.password, WEBHOOK_PASSWORD)
+    if not (is_user_ok and is_pass_ok):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail='Incorrect username or password.',
+            headers={'WWW-Authenticate': 'Basic'},
+        )


-async def _perform_commands_azure(commands_conf: str, agent: PRAgent, api_url: str, log_context: dict):
+async def _perform_commands_azure(
+    commands_conf: str, agent: PRAgent, api_url: str, log_context: dict
+):
     apply_repo_settings(api_url)
-    if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback:  # auto commands for PR, and auto feedback is disabled
-        get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}", **log_context)
+    if (
+        commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
+    ):  # auto commands for PR, and auto feedback is disabled
+        get_logger().info(
+            f"Auto feedback is disabled, skipping auto commands for PR {api_url=}",
+            **log_context,
+        )
         return
     commands = get_settings().get(f"azure_devops_server.{commands_conf}")
     get_settings().set("config.is_auto_command", True)
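`secrets.compare_digest` gives a constant-time string comparison, which avoids leaking how many leading characters of a credential matched. A minimal sketch of the same check outside FastAPI (the expected values are made up):

```python
import secrets

EXPECTED_USER = "webhook-user"      # placeholder
EXPECTED_PASS = "webhook-password"  # placeholder

def credentials_ok(username: str, password: str) -> bool:
    # Both comparisons always run, so timing reveals nothing about which failed.
    is_user_ok = secrets.compare_digest(username, EXPECTED_USER)
    is_pass_ok = secrets.compare_digest(password, EXPECTED_PASS)
    return is_user_ok and is_pass_ok

print(credentials_ok("webhook-user", "wrong"))  # False
```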
@@ -92,22 +100,38 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
     actions = []
     if data["eventType"] == "git.pullrequest.created":
         # API V1 (latest)
-        pr_url = unquote(data["resource"]["_links"]["web"]["href"].replace("_apis/git/repositories", "_git"))
+        pr_url = unquote(
+            data["resource"]["_links"]["web"]["href"].replace(
+                "_apis/git/repositories", "_git"
+            )
+        )
         log_context["event"] = data["eventType"]
         log_context["api_url"] = pr_url
         await _perform_commands_azure("pr_commands", PRAgent(), pr_url, log_context)
         return
-    elif data["eventType"] == "ms.vss-code.git-pullrequest-comment-event" and "content" in data["resource"]["comment"]:
+    elif (
+        data["eventType"] == "ms.vss-code.git-pullrequest-comment-event"
+        and "content" in data["resource"]["comment"]
+    ):
         if available_commands_rgx.match(data["resource"]["comment"]["content"]):
-            if(data["resourceVersion"] == "2.0"):
+            if data["resourceVersion"] == "2.0":
                 repo = data["resource"]["pullRequest"]["repository"]["webUrl"]
-                pr_url = unquote(f'{repo}/pullrequest/{data["resource"]["pullRequest"]["pullRequestId"]}')
+                pr_url = unquote(
+                    f'{repo}/pullrequest/{data["resource"]["pullRequest"]["pullRequestId"]}'
+                )
                 actions = [data["resource"]["comment"]["content"]]
             else:
                 # API V1 not supported as it does not contain the PR URL
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content=json.dumps({"message": "version 1.0 webhook for Azure Devops PR comment is not supported. please upgrade to version 2.0"})),
+                return (
+                    JSONResponse(
+                        status_code=status.HTTP_400_BAD_REQUEST,
+                        content=json.dumps(
+                            {
+                                "message": "version 1.0 webhook for Azure Devops PR comment is not supported. please upgrade to version 2.0"
+                            }
+                        ),
+                    ),
+                )
     else:
         return JSONResponse(
             status_code=status.HTTP_400_BAD_REQUEST,

@@ -132,17 +156,21 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
             content=json.dumps({"message": "Internal server error"}),
         )
     return JSONResponse(
-        status_code=status.HTTP_202_ACCEPTED, content=jsonable_encoder({"message": "webhook triggered successfully"})
+        status_code=status.HTTP_202_ACCEPTED,
+        content=jsonable_encoder({"message": "webhook triggered successfully"}),
     )


 @router.get("/")
 async def root():
     return {"status": "ok"}


 def start():
     app = FastAPI(middleware=[Middleware(RawContextMiddleware)])
     app.include_router(router)
     uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "3000")))


 if __name__ == "__main__":
     start()
@@ -27,7 +27,9 @@ from utils.pr_agent.secret_providers import get_secret_provider

 setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
 router = APIRouter()
-secret_provider = get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
+secret_provider = (
+    get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
+)


 async def get_bearer_token(shared_secret: str, client_key: str):

@@ -44,12 +46,12 @@ async def get_bearer_token(shared_secret: str, client_key: str):
             "exp": now + 240,
             "qsh": qsh,
             "sub": client_key,
-            }
+        }
         token = jwt.encode(payload, shared_secret, algorithm="HS256")
         payload = 'grant_type=urn%3Abitbucket%3Aoauth2%3Ajwt'
         headers = {
             'Authorization': f'JWT {token}',
-            'Content-Type': 'application/x-www-form-urlencoded'
+            'Content-Type': 'application/x-www-form-urlencoded',
         }
         response = requests.request("POST", url, headers=headers, data=payload)
         bearer_token = response.json()["access_token"]

@@ -58,6 +60,7 @@ async def get_bearer_token(shared_secret: str, client_key: str):
         get_logger().error(f"Failed to get bearer token: {e}")
         raise e

+
 @router.get("/")
 async def handle_manifest(request: Request, response: Response):
     cur_dir = os.path.dirname(os.path.abspath(__file__))
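`get_bearer_token` mints a short-lived Atlassian-style JWT and trades it for an OAuth bearer token. A hedged sketch of just the token-minting step with PyJWT (the secret, issuer, and `qsh` value are placeholders; a real Bitbucket request needs a proper query-string hash):

```python
import time

import jwt  # PyJWT

shared_secret = "placeholder-shared-secret"
client_key = "placeholder-client-key"
now = int(time.time())
payload = {
    "iss": client_key,
    "iat": now,
    "exp": now + 240,          # four-minute lifetime, as above
    "qsh": "placeholder-qsh",  # query-string hash of the request being signed
    "sub": client_key,
}
token = jwt.encode(payload, shared_secret, algorithm="HS256")
print(token)
```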
@@ -66,7 +69,9 @@ async def handle_manifest(request: Request, response: Response):
         manifest = manifest.replace("app_key", get_settings().bitbucket.app_key)
         manifest = manifest.replace("base_url", get_settings().bitbucket.base_url)
     except:
-        get_logger().error("Failed to replace api_key in Bitbucket manifest, trying to continue")
+        get_logger().error(
+            "Failed to replace api_key in Bitbucket manifest, trying to continue"
+        )
     manifest_obj = json.loads(manifest)
     return JSONResponse(manifest_obj)

@@ -83,10 +88,16 @@ def _get_username(data):
     return ""


-async def _perform_commands_bitbucket(commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict):
+async def _perform_commands_bitbucket(
+    commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict
+):
     apply_repo_settings(api_url)
-    if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback:  # auto commands for PR, and auto feedback is disabled
-        get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}")
+    if (
+        commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
+    ):  # auto commands for PR, and auto feedback is disabled
+        get_logger().info(
+            f"Auto feedback is disabled, skipping auto commands for PR {api_url=}"
+        )
         return
     if data.get("event", "") == "pullrequest:created":
         if not should_process_pr_logic(data):

@@ -132,7 +143,9 @@ def should_process_pr_logic(data) -> bool:
         ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
         if ignore_pr_users and sender:
             if sender in ignore_pr_users:
-                get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting")
+                get_logger().info(
+                    f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting"
+                )
                 return False

         # logic to ignore PRs with specific titles

@@ -140,20 +153,34 @@ def should_process_pr_logic(data) -> bool:
             ignore_pr_title_re = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
             if not isinstance(ignore_pr_title_re, list):
                 ignore_pr_title_re = [ignore_pr_title_re]
-            if ignore_pr_title_re and any(re.search(regex, title) for regex in ignore_pr_title_re):
-                get_logger().info(f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting")
+            if ignore_pr_title_re and any(
+                re.search(regex, title) for regex in ignore_pr_title_re
+            ):
+                get_logger().info(
+                    f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting"
+                )
                 return False

-            ignore_pr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [])
-            ignore_pr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", [])
-            if (ignore_pr_source_branches or ignore_pr_target_branches):
-                if any(re.search(regex, source_branch) for regex in ignore_pr_source_branches):
+            ignore_pr_source_branches = get_settings().get(
+                "CONFIG.IGNORE_PR_SOURCE_BRANCHES", []
+            )
+            ignore_pr_target_branches = get_settings().get(
+                "CONFIG.IGNORE_PR_TARGET_BRANCHES", []
+            )
+            if ignore_pr_source_branches or ignore_pr_target_branches:
+                if any(
+                    re.search(regex, source_branch) for regex in ignore_pr_source_branches
+                ):
                     get_logger().info(
-                        f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings")
+                        f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings"
+                    )
                     return False
-                if any(re.search(regex, target_branch) for regex in ignore_pr_target_branches):
+                if any(
+                    re.search(regex, target_branch) for regex in ignore_pr_target_branches
+                ):
                     get_logger().info(
-                        f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings")
+                        f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings"
+                    )
                     return False
     except Exception as e:
         get_logger().error(f"Failed 'should_process_pr_logic': {e}")
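The branch filters treat each configured entry as a regular expression and skip the PR on any match. A small standalone sketch with hypothetical patterns:

```python
import re

ignore_pr_source_branches = [r"^dependabot/", r"^renovate/"]  # hypothetical config
source_branch = "dependabot/pip/requests-2.32.0"

if any(re.search(regex, source_branch) for regex in ignore_pr_source_branches):
    print(f"Ignoring PR with source branch '{source_branch}'")
```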
@@ -195,7 +222,9 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
             client_key = claims["iss"]
             secrets = json.loads(secret_provider.get_secret(client_key))
             shared_secret = secrets["shared_secret"]
-            jwt.decode(input_jwt, shared_secret, audience=client_key, algorithms=["HS256"])
+            jwt.decode(
+                input_jwt, shared_secret, audience=client_key, algorithms=["HS256"]
+            )
             bearer_token = await get_bearer_token(shared_secret, client_key)
             context['bitbucket_bearer_token'] = bearer_token
             context["settings"] = copy.deepcopy(global_settings)

@@ -208,28 +237,41 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
             if pr_url:
                 with get_logger().contextualize(**log_context):
                     apply_repo_settings(pr_url)
-                    if get_identity_provider().verify_eligibility("bitbucket",
-                                                                  sender_id, pr_url) is not Eligibility.NOT_ELIGIBLE:
+                    if (
+                        get_identity_provider().verify_eligibility(
+                            "bitbucket", sender_id, pr_url
+                        )
+                        is not Eligibility.NOT_ELIGIBLE
+                    ):
                         if get_settings().get("bitbucket_app.pr_commands"):
-                            await _perform_commands_bitbucket("pr_commands", PRAgent(), pr_url, log_context, data)
+                            await _perform_commands_bitbucket(
+                                "pr_commands", PRAgent(), pr_url, log_context, data
+                            )
         elif event == "pullrequest:comment_created":
             pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
             log_context["api_url"] = pr_url
             log_context["event"] = "comment"
             comment_body = data["data"]["comment"]["content"]["raw"]
             with get_logger().contextualize(**log_context):
-                if get_identity_provider().verify_eligibility("bitbucket",
-                                                              sender_id, pr_url) is not Eligibility.NOT_ELIGIBLE:
+                if (
+                    get_identity_provider().verify_eligibility(
+                        "bitbucket", sender_id, pr_url
+                    )
+                    is not Eligibility.NOT_ELIGIBLE
+                ):
                     await agent.handle_request(pr_url, comment_body)
     except Exception as e:
         get_logger().error(f"Failed to handle webhook: {e}")

     background_tasks.add_task(inner)
     return "OK"


 @router.get("/webhook")
 async def handle_github_webhooks(request: Request, response: Response):
     return "Webhook server online!"


 @router.post("/installed")
 async def handle_installed_webhooks(request: Request, response: Response):
     try:

@@ -240,15 +282,13 @@ async def handle_installed_webhooks(request: Request, response: Response):
         shared_secret = data["sharedSecret"]
         client_key = data["clientKey"]
         username = data["principal"]["username"]
-        secrets = {
-            "shared_secret": shared_secret,
-            "client_key": client_key
-        }
+        secrets = {"shared_secret": shared_secret, "client_key": client_key}
         secret_provider.store_secret(username, json.dumps(secrets))
     except Exception as e:
         get_logger().error(f"Failed to register user: {e}")
         return JSONResponse({"error": "Unable to register user"}, status_code=500)


 @router.post("/uninstalled")
 async def handle_uninstalled_webhooks(request: Request, response: Response):
     get_logger().info("handle_uninstalled_webhooks")
@@ -40,10 +40,12 @@ def handle_request(

     background_tasks.add_task(inner)

+
 @router.post("/")
 async def redirect_to_webhook():
     return RedirectResponse(url="/webhook")

+
 @router.post("/webhook")
 async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
     log_context = {"server_type": "bitbucket_server"}

@@ -55,7 +57,8 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
     body_bytes = await request.body()
     if body_bytes.decode('utf-8') == '{"test": true}':
         return JSONResponse(
-            status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "connection test successful"})
+            status_code=status.HTTP_200_OK,
+            content=jsonable_encoder({"message": "connection test successful"}),
         )
     signature_header = request.headers.get("x-hub-signature", None)
     verify_signature(body_bytes, webhook_secret, signature_header)

@@ -73,11 +76,18 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):

     if data["eventKey"] == "pr:opened":
         apply_repo_settings(pr_url)
-        if get_settings().config.disable_auto_feedback:  # auto commands for PR, and auto feedback is disabled
-            get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {pr_url}", **log_context)
+        if (
+            get_settings().config.disable_auto_feedback
+        ):  # auto commands for PR, and auto feedback is disabled
+            get_logger().info(
+                f"Auto feedback is disabled, skipping auto commands for PR {pr_url}",
+                **log_context,
+            )
             return
         get_settings().set("config.is_auto_command", True)
-        commands_to_run.extend(_get_commands_list_from_settings('BITBUCKET_SERVER.PR_COMMANDS'))
+        commands_to_run.extend(
+            _get_commands_list_from_settings('BITBUCKET_SERVER.PR_COMMANDS')
+        )
     elif data["eventKey"] == "pr:comment:added":
         commands_to_run.append(data["comment"]["text"])
     else:

@@ -116,6 +126,7 @@ async def _run_commands_sequentially(commands: List[str], url: str, log_context:
     except Exception as e:
         get_logger().error(f"Failed to handle command: {command} , error: {e}")

+
 def _process_command(command: str, url) -> str:
     # don't think we need this
     apply_repo_settings(url)

@@ -142,11 +153,13 @@ def _to_list(command_string: str) -> list:
         raise ValueError(f"Invalid command string: {e}")


-def _get_commands_list_from_settings(setting_key:str ) -> list:
+def _get_commands_list_from_settings(setting_key: str) -> list:
     try:
         return get_settings().get(setting_key, [])
     except ValueError as e:
-        get_logger().error(f"Failed to get commands list from settings {setting_key}: {e}")
+        get_logger().error(
+            f"Failed to get commands list from settings {setting_key}: {e}"
+        )


 @router.get("/")
@@ -40,12 +40,10 @@ async def handle_gerrit_request(action: Action, item: Item):
     if action == Action.ask:
         if not item.msg:
             return HTTPException(
-                status_code=400,
-                detail="msg is required for ask command"
+                status_code=400, detail="msg is required for ask command"
             )
     await PRAgent().handle_request(
-        f"{item.project}:{item.refspec}",
-        f"/{item.msg.strip()}"
+        f"{item.project}:{item.refspec}", f"/{item.msg.strip()}"
     )

@@ -26,7 +26,12 @@ def get_setting_or_env(key: str, default: Union[str, bool] = None) -> Union[str,
     try:
         value = get_settings().get(key, default)
     except AttributeError:  # TBD still need to debug why this happens on GitHub Actions
-        value = os.getenv(key, None) or os.getenv(key.upper(), None) or os.getenv(key.lower(), None) or default
+        value = (
+            os.getenv(key, None)
+            or os.getenv(key.upper(), None)
+            or os.getenv(key.lower(), None)
+            or default
+        )
     return value
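The fallback chain relies on `or` short-circuiting: the first non-empty hit among the exact, upper-cased, and lower-cased variable names wins, else the default. A tiny standalone sketch (the variable name is made up):

```python
import os

key, default = "MAX_MODEL_TOKENS", "8000"
value = (
    os.getenv(key, None)
    or os.getenv(key.upper(), None)
    or os.getenv(key.lower(), None)
    or default
)
print(value)  # the env value if set in any casing, otherwise "8000"
```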
@@ -76,16 +81,24 @@ async def run_action():
         pr_url = event_payload.get("pull_request", {}).get("html_url")
         if pr_url:
             apply_repo_settings(pr_url)
-            get_logger().info(f"enable_custom_labels: {get_settings().config.enable_custom_labels}")
+            get_logger().info(
+                f"enable_custom_labels: {get_settings().config.enable_custom_labels}"
+            )
     except Exception as e:
         get_logger().info(f"github action: failed to apply repo settings: {e}")

     # Handle pull request opened event
-    if GITHUB_EVENT_NAME == "pull_request" or GITHUB_EVENT_NAME == "pull_request_target":
+    if (
+        GITHUB_EVENT_NAME == "pull_request"
+        or GITHUB_EVENT_NAME == "pull_request_target"
+    ):
         action = event_payload.get("action")

         # Retrieve the list of actions from the configuration
-        pr_actions = get_settings().get("GITHUB_ACTION_CONFIG.PR_ACTIONS", ["opened", "reopened", "ready_for_review", "review_requested"])
+        pr_actions = get_settings().get(
+            "GITHUB_ACTION_CONFIG.PR_ACTIONS",
+            ["opened", "reopened", "ready_for_review", "review_requested"],
+        )

         if action in pr_actions:
             pr_url = event_payload.get("pull_request", {}).get("url")

@@ -93,18 +106,30 @@ async def run_action():
             # legacy - supporting both GITHUB_ACTION and GITHUB_ACTION_CONFIG
             auto_review = get_setting_or_env("GITHUB_ACTION.AUTO_REVIEW", None)
             if auto_review is None:
-                auto_review = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_REVIEW", None)
+                auto_review = get_setting_or_env(
+                    "GITHUB_ACTION_CONFIG.AUTO_REVIEW", None
+                )
             auto_describe = get_setting_or_env("GITHUB_ACTION.AUTO_DESCRIBE", None)
             if auto_describe is None:
-                auto_describe = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_DESCRIBE", None)
+                auto_describe = get_setting_or_env(
+                    "GITHUB_ACTION_CONFIG.AUTO_DESCRIBE", None
+                )
             auto_improve = get_setting_or_env("GITHUB_ACTION.AUTO_IMPROVE", None)
             if auto_improve is None:
-                auto_improve = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_IMPROVE", None)
+                auto_improve = get_setting_or_env(
+                    "GITHUB_ACTION_CONFIG.AUTO_IMPROVE", None
+                )

             # Set the configuration for auto actions
-            get_settings().config.is_auto_command = True  # Set the flag to indicate that the command is auto
-            get_settings().pr_description.final_update_message = False  # No final update message when auto_describe is enabled
-            get_logger().info(f"Running auto actions: auto_describe={auto_describe}, auto_review={auto_review}, auto_improve={auto_improve}")
+            get_settings().config.is_auto_command = (
+                True  # Set the flag to indicate that the command is auto
+            )
+            get_settings().pr_description.final_update_message = (
+                False  # No final update message when auto_describe is enabled
+            )
+            get_logger().info(
+                f"Running auto actions: auto_describe={auto_describe}, auto_review={auto_review}, auto_improve={auto_improve}"
+            )

             # invoke by default all three tools
             if auto_describe is None or is_true(auto_describe):

@@ -117,7 +142,10 @@ async def run_action():
             get_logger().info(f"Skipping action: {action}")

     # Handle issue comment event
-    elif GITHUB_EVENT_NAME == "issue_comment" or GITHUB_EVENT_NAME == "pull_request_review_comment":
+    elif (
+        GITHUB_EVENT_NAME == "issue_comment"
+        or GITHUB_EVENT_NAME == "pull_request_review_comment"
+    ):
         action = event_payload.get("action")
         if action in ["created", "edited"]:
             comment_body = event_payload.get("comment", {}).get("body")

@@ -133,9 +161,15 @@ async def run_action():
             disable_eyes = False
             # check if issue is pull request
             if event_payload.get("issue", {}).get("pull_request"):
-                url = event_payload.get("issue", {}).get("pull_request", {}).get("url")
+                url = (
+                    event_payload.get("issue", {})
+                    .get("pull_request", {})
+                    .get("url")
+                )
                 is_pr = True
-            elif event_payload.get("comment", {}).get("pull_request_url"):  # for 'pull_request_review_comment
+            elif event_payload.get("comment", {}).get(
+                "pull_request_url"
+            ):  # for 'pull_request_review_comment
                 url = event_payload.get("comment", {}).get("pull_request_url")
                 is_pr = True
                 disable_eyes = True

@@ -148,9 +182,11 @@ async def run_action():
             provider = get_git_provider()(pr_url=url)
             if is_pr:
                 await PRAgent().handle_request(
-                    url, body, notify=lambda: provider.add_eyes_reaction(
+                    url,
+                    body,
+                    notify=lambda: provider.add_eyes_reaction(
                         comment_id, disable_eyes=disable_eyes
-                    )
+                    ),
                 )
             else:
                 await PRAgent().handle_request(url, body)
@ -15,8 +15,7 @@ from starlette_context.middleware import RawContextMiddleware
from utils.pr_agent.agent.pr_agent import PRAgent
from utils.pr_agent.algo.utils import update_settings_from_args
from utils.pr_agent.config_loader import get_settings, global_settings
from utils.pr_agent.git_providers import (get_git_provider,
get_git_provider_with_context)
from utils.pr_agent.git_providers import get_git_provider, get_git_provider_with_context
from utils.pr_agent.git_providers.utils import apply_repo_settings
from utils.pr_agent.identity_providers import get_identity_provider
from utils.pr_agent.identity_providers.identity_provider import Eligibility
@ -35,7 +34,9 @@ router = APIRouter()

@router.post("/api/v1/github_webhooks")
async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Request, response: Response):
async def handle_github_webhooks(
background_tasks: BackgroundTasks, request: Request, response: Response
):
"""
Receives and processes incoming GitHub webhook requests.
Verifies the request signature, parses the request body, and passes it to the handle_request function for further
@ -49,7 +50,9 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
context["installation_id"] = installation_id
context["settings"] = copy.deepcopy(global_settings)
context["git_provider"] = {}
background_tasks.add_task(handle_request, body, event=request.headers.get("X-GitHub-Event", None))
background_tasks.add_task(
handle_request, body, event=request.headers.get("X-GitHub-Event", None)
)
return {}

@ -73,35 +76,61 @@ async def get_body(request):
return body

_duplicate_push_triggers = DefaultDictWithTimeout(ttl=get_settings().github_app.push_trigger_pending_tasks_ttl)
_pending_task_duplicate_push_conditions = DefaultDictWithTimeout(asyncio.locks.Condition, ttl=get_settings().github_app.push_trigger_pending_tasks_ttl)
_duplicate_push_triggers = DefaultDictWithTimeout(
ttl=get_settings().github_app.push_trigger_pending_tasks_ttl
)
_pending_task_duplicate_push_conditions = DefaultDictWithTimeout(
asyncio.locks.Condition,
ttl=get_settings().github_app.push_trigger_pending_tasks_ttl,
)

async def handle_comments_on_pr(body: Dict[str, Any],
event: str,
sender: str,
sender_id: str,
action: str,
log_context: Dict[str, Any],
agent: PRAgent):

async def handle_comments_on_pr(
body: Dict[str, Any],
event: str,
sender: str,
sender_id: str,
action: str,
log_context: Dict[str, Any],
agent: PRAgent,
):
if "comment" not in body:
return {}
comment_body = body.get("comment", {}).get("body")
if comment_body and isinstance(comment_body, str) and not comment_body.lstrip().startswith("/"):
if (
comment_body
and isinstance(comment_body, str)
and not comment_body.lstrip().startswith("/")
):
if '/ask' in comment_body and comment_body.strip().startswith('> ![image]'):
comment_body_split = comment_body.split('/ask')
comment_body = '/ask' + comment_body_split[1] +' \n' +comment_body_split[0].strip().lstrip('>')
get_logger().info(f"Reformatting comment_body so command is at the beginning: {comment_body}")
comment_body = (
'/ask'
+ comment_body_split[1]
+ ' \n'
+ comment_body_split[0].strip().lstrip('>')
)
get_logger().info(
f"Reformatting comment_body so command is at the beginning: {comment_body}"
)
else:
get_logger().info("Ignoring comment not starting with /")
return {}
disable_eyes = False
if "issue" in body and "pull_request" in body["issue"] and "url" in body["issue"]["pull_request"]:
if (
"issue" in body
and "pull_request" in body["issue"]
and "url" in body["issue"]["pull_request"]
):
api_url = body["issue"]["pull_request"]["url"]
elif "comment" in body and "pull_request_url" in body["comment"]:
api_url = body["comment"]["pull_request_url"]
try:
if ('/ask' in comment_body and
'subject_type' in body["comment"] and body["comment"]["subject_type"] == "line"):
if (
'/ask' in comment_body
and 'subject_type' in body["comment"]
and body["comment"]["subject_type"] == "line"
):
# comment on a code line in the "files changed" tab
comment_body = handle_line_comments(body, comment_body)
disable_eyes = True
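
The `/ask`-reordering branch above is easier to follow in isolation. A standalone sketch of the same string manipulation (toy input, not the app's actual entry point):

# When a reply quotes an image ("> ![image](...)") before the /ask command,
# move the command to the front so downstream parsing sees it first.
def reorder_ask(comment_body: str) -> str:
    if '/ask' in comment_body and comment_body.strip().startswith('> ![image]'):
        quoted, question = comment_body.split('/ask', 1)
        comment_body = '/ask' + question + ' \n' + quoted.strip().lstrip('>')
    return comment_body

print(reorder_ask("> ![image](https://example.com/x.png)\n/ask what changed here?"))
# -> "/ask what changed here? \n ![image](https://example.com/x.png)"
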
@ -113,46 +142,75 @@ async def handle_comments_on_pr(body: Dict[str, Any],
comment_id = body.get("comment", {}).get("id")
provider = get_git_provider_with_context(pr_url=api_url)
with get_logger().contextualize(**log_context):
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
get_logger().info(f"Processing comment on PR {api_url=}, comment_body={comment_body}")
await agent.handle_request(api_url, comment_body,
notify=lambda: provider.add_eyes_reaction(comment_id, disable_eyes=disable_eyes))
if (
get_identity_provider().verify_eligibility("github", sender_id, api_url)
is not Eligibility.NOT_ELIGIBLE
):
get_logger().info(
f"Processing comment on PR {api_url=}, comment_body={comment_body}"
)
await agent.handle_request(
api_url,
comment_body,
notify=lambda: provider.add_eyes_reaction(
comment_id, disable_eyes=disable_eyes
),
)
else:
get_logger().info(f"User {sender=} is not eligible to process comment on PR {api_url=}")
get_logger().info(
f"User {sender=} is not eligible to process comment on PR {api_url=}"
)

async def handle_new_pr_opened(body: Dict[str, Any],
event: str,
sender: str,
sender_id: str,
action: str,
log_context: Dict[str, Any],
agent: PRAgent):

async def handle_new_pr_opened(
body: Dict[str, Any],
event: str,
sender: str,
sender_id: str,
action: str,
log_context: Dict[str, Any],
agent: PRAgent,
):
title = body.get("pull_request", {}).get("title", "")

pull_request, api_url = _check_pull_request_event(action, body, log_context)
if not (pull_request and api_url):
get_logger().info(f"Invalid PR event: {action=} {api_url=}")
return {}
if action in get_settings().github_app.handle_pr_actions: # ['opened', 'reopened', 'ready_for_review']
if (
action in get_settings().github_app.handle_pr_actions
): # ['opened', 'reopened', 'ready_for_review']
# logic to ignore PRs with specific titles (e.g. "[Auto] ...")
apply_repo_settings(api_url)
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
await _perform_auto_commands_github("pr_commands", agent, body, api_url, log_context)
if (
get_identity_provider().verify_eligibility("github", sender_id, api_url)
is not Eligibility.NOT_ELIGIBLE
):
await _perform_auto_commands_github(
"pr_commands", agent, body, api_url, log_context
)
else:
get_logger().info(f"User {sender=} is not eligible to process PR {api_url=}")
get_logger().info(
f"User {sender=} is not eligible to process PR {api_url=}"
)

async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
event: str,
sender: str,
sender_id: str,
action: str,
log_context: Dict[str, Any],
agent: PRAgent):

async def handle_push_trigger_for_new_commits(
body: Dict[str, Any],
event: str,
sender: str,
sender_id: str,
action: str,
log_context: Dict[str, Any],
agent: PRAgent,
):
pull_request, api_url = _check_pull_request_event(action, body, log_context)
if not (pull_request and api_url):
return {}

apply_repo_settings(api_url) # we need to apply the repo settings to get the correct settings for the PR. This is quite expensive - a call to the git provider is made for each PR event.
apply_repo_settings(
api_url
) # we need to apply the repo settings to get the correct settings for the PR. This is quite expensive - a call to the git provider is made for each PR event.
if not get_settings().github_app.handle_push_trigger:
return {}

@ -162,7 +220,10 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
merge_commit_sha = pull_request.get("merge_commit_sha")
if before_sha == after_sha:
return {}
if get_settings().github_app.push_trigger_ignore_merge_commits and after_sha == merge_commit_sha:
if (
get_settings().github_app.push_trigger_ignore_merge_commits
and after_sha == merge_commit_sha
):
return {}

# Prevent triggering multiple times for subsequent push triggers when one is enough:
@ -172,7 +233,9 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
# more commits may have been pushed that led to the subsequent events,
# so we keep just one waiting as a delegate to trigger the processing for the new commits when done waiting.
current_active_tasks = _duplicate_push_triggers.setdefault(api_url, 0)
max_active_tasks = 2 if get_settings().github_app.push_trigger_pending_tasks_backlog else 1
max_active_tasks = (
2 if get_settings().github_app.push_trigger_pending_tasks_backlog else 1
)
if current_active_tasks < max_active_tasks:
# first task can enter, and second tasks too if backlog is enabled
get_logger().info(
@ -191,12 +254,21 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
f"Waiting to process push trigger for {api_url=} because the first task is still in progress"
)
await _pending_task_duplicate_push_conditions[api_url].wait()
get_logger().info(f"Finished waiting to process push trigger for {api_url=} - continue with flow")
get_logger().info(
f"Finished waiting to process push trigger for {api_url=} - continue with flow"
)

try:
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
get_logger().info(f"Performing incremental review for {api_url=} because of {event=} and {action=}")
await _perform_auto_commands_github("push_commands", agent, body, api_url, log_context)
if (
get_identity_provider().verify_eligibility("github", sender_id, api_url)
is not Eligibility.NOT_ELIGIBLE
):
get_logger().info(
f"Performing incremental review for {api_url=} because of {event=} and {action=}"
)
await _perform_auto_commands_github(
"push_commands", agent, body, api_url, log_context
)

finally:
# release the waiting task block
@ -213,7 +285,12 @@ def handle_closed_pr(body, event, action, log_context):
api_url = pull_request.get("url", "")
pr_statistics = get_git_provider()(pr_url=api_url).calc_pr_statistics(pull_request)
log_context["api_url"] = api_url
get_logger().info("PR-Agent statistics for closed PR", analytics=True, pr_statistics=pr_statistics, **log_context)
get_logger().info(
"PR-Agent statistics for closed PR",
analytics=True,
pr_statistics=pr_statistics,
**log_context,
)

def get_log_context(body, event, action, build_number):
@ -228,9 +305,18 @@ def get_log_context(body, event, action, build_number):
git_org = body.get("organization", {}).get("login", "")
installation_id = body.get("installation", {}).get("id", "")
app_name = get_settings().get("CONFIG.APP_NAME", "Unknown")
log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app",
"request_id": uuid.uuid4().hex, "build_number": build_number, "app_name": app_name,
"repo": repo, "git_org": git_org, "installation_id": installation_id}
log_context = {
"action": action,
"event": event,
"sender": sender,
"server_type": "github_app",
"request_id": uuid.uuid4().hex,
"build_number": build_number,
"app_name": app_name,
"repo": repo,
"git_org": git_org,
"installation_id": installation_id,
}
except Exception as e:
get_logger().error("Failed to get log context", e)
log_context = {}
@ -240,7 +326,10 @@ def get_log_context(body, event, action, build_number):
def is_bot_user(sender, sender_type):
try:
# logic to ignore PRs opened by bot
if get_settings().get("GITHUB_APP.IGNORE_BOT_PR", False) and sender_type == "Bot":
if (
get_settings().get("GITHUB_APP.IGNORE_BOT_PR", False)
and sender_type == "Bot"
):
if 'pr-agent' not in sender:
get_logger().info(f"Ignoring PR from '{sender=}' because it is a bot")
return True
@ -262,7 +351,9 @@ def should_process_pr_logic(body) -> bool:
ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
if ignore_pr_users and sender:
if sender in ignore_pr_users:
get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting")
get_logger().info(
f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting"
)
return False

# logic to ignore PRs with specific titles
@ -270,8 +361,12 @@ def should_process_pr_logic(body) -> bool:
ignore_pr_title_re = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
if not isinstance(ignore_pr_title_re, list):
ignore_pr_title_re = [ignore_pr_title_re]
if ignore_pr_title_re and any(re.search(regex, title) for regex in ignore_pr_title_re):
get_logger().info(f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting")
if ignore_pr_title_re and any(
re.search(regex, title) for regex in ignore_pr_title_re
):
get_logger().info(
f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting"
)
return False

# logic to ignore PRs with specific labels or source branches or target branches.
@ -280,20 +375,32 @@ def should_process_pr_logic(body) -> bool:
labels = [label['name'] for label in pr_labels]
if any(label in ignore_pr_labels for label in labels):
labels_str = ", ".join(labels)
get_logger().info(f"Ignoring PR with labels '{labels_str}' due to config.ignore_pr_labels settings")
get_logger().info(
f"Ignoring PR with labels '{labels_str}' due to config.ignore_pr_labels settings"
)
return False

# logic to ignore PRs with specific source or target branches
ignore_pr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [])
ignore_pr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", [])
ignore_pr_source_branches = get_settings().get(
"CONFIG.IGNORE_PR_SOURCE_BRANCHES", []
)
ignore_pr_target_branches = get_settings().get(
"CONFIG.IGNORE_PR_TARGET_BRANCHES", []
)
if pull_request and (ignore_pr_source_branches or ignore_pr_target_branches):
if any(re.search(regex, source_branch) for regex in ignore_pr_source_branches):
if any(
re.search(regex, source_branch) for regex in ignore_pr_source_branches
):
get_logger().info(
f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings")
f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings"
)
return False
if any(re.search(regex, target_branch) for regex in ignore_pr_target_branches):
if any(
re.search(regex, target_branch) for regex in ignore_pr_target_branches
):
get_logger().info(
f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings")
f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings"
)
return False
except Exception as e:
get_logger().error(f"Failed 'should_process_pr_logic': {e}")
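
Every filter in `should_process_pr_logic` reduces to one pattern: a configured regex or label that matches the PR's metadata short-circuits processing. A minimal standalone sketch (the settings values below are invented examples):

import re

# Returns False as soon as any ignore rule fires, mirroring the early
# `return False` branches in the hunk above.
def should_process(title: str, labels: list, ignore_title_res: list, ignore_labels: list) -> bool:
    if any(re.search(rx, title) for rx in ignore_title_res):
        return False  # title matched an ignore pattern
    if any(label in ignore_labels for label in labels):
        return False  # PR carries an ignored label
    return True

assert should_process("Fix login bug", [], [r"^\[Auto\]"], ["wip"]) is True
assert should_process("[Auto] nightly bump", [], [r"^\[Auto\]"], ["wip"]) is False
assert should_process("Refactor parser", ["wip"], [r"^\[Auto\]"], ["wip"]) is False
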
@ -308,11 +415,15 @@ async def handle_request(body: Dict[str, Any], event: str):
body: The request body.
event: The GitHub event type (e.g. "pull_request", "issue_comment", etc.).
"""
action = body.get("action") # "created", "opened", "reopened", "ready_for_review", "review_requested", "synchronize"
action = body.get(
"action"
) # "created", "opened", "reopened", "ready_for_review", "review_requested", "synchronize"
if not action:
return {}
agent = PRAgent()
log_context, sender, sender_id, sender_type = get_log_context(body, event, action, build_number)
log_context, sender, sender_id, sender_type = get_log_context(
body, event, action, build_number
)

# logic to ignore PRs opened by bot, PRs with specific titles, labels, source branches, or target branches
if is_bot_user(sender, sender_type) and 'check_run' not in body:
@ -327,21 +438,29 @@ async def handle_request(body: Dict[str, Any], event: str):
# handle comments on PRs
elif action == 'created':
get_logger().debug(f'Request body', artifact=body, event=event)
await handle_comments_on_pr(body, event, sender, sender_id, action, log_context, agent)
await handle_comments_on_pr(
body, event, sender, sender_id, action, log_context, agent
)
# handle new PRs
elif event == 'pull_request' and action != 'synchronize' and action != 'closed':
get_logger().debug(f'Request body', artifact=body, event=event)
await handle_new_pr_opened(body, event, sender, sender_id, action, log_context, agent)
await handle_new_pr_opened(
body, event, sender, sender_id, action, log_context, agent
)
elif event == "issue_comment" and 'edited' in action:
pass # handle_checkbox_clicked
pass # handle_checkbox_clicked
# handle pull_request event with synchronize action - "push trigger" for new commits
elif event == 'pull_request' and action == 'synchronize':
await handle_push_trigger_for_new_commits(body, event, sender,sender_id, action, log_context, agent)
await handle_push_trigger_for_new_commits(
body, event, sender, sender_id, action, log_context, agent
)
elif event == 'pull_request' and action == 'closed':
if get_settings().get("CONFIG.ANALYTICS_FOLDER", ""):
handle_closed_pr(body, event, action, log_context)
else:
get_logger().info(f"event {event=} action {action=} does not require any handling")
get_logger().info(
f"event {event=} action {action=} does not require any handling"
)
return {}

@ -362,7 +481,9 @@ def handle_line_comments(body: Dict, comment_body: [str, Any]) -> str:
return comment_body

def _check_pull_request_event(action: str, body: dict, log_context: dict) -> Tuple[Dict[str, Any], str]:
def _check_pull_request_event(
action: str, body: dict, log_context: dict
) -> Tuple[Dict[str, Any], str]:
invalid_result = {}, ""
pull_request = body.get("pull_request")
if not pull_request:
@ -373,19 +494,28 @@ def _check_pull_request_event(action: str, body: dict, log_context: dict) -> Tup
log_context["api_url"] = api_url
if pull_request.get("draft", True) or pull_request.get("state") != "open":
return invalid_result
if action in ("review_requested", "synchronize") and pull_request.get("created_at") == pull_request.get("updated_at"):
if action in ("review_requested", "synchronize") and pull_request.get(
"created_at"
) == pull_request.get("updated_at"):
# avoid double reviews when opening a PR for the first time
return invalid_result
return pull_request, api_url

async def _perform_auto_commands_github(commands_conf: str, agent: PRAgent, body: dict, api_url: str,
log_context: dict):
async def _perform_auto_commands_github(
commands_conf: str, agent: PRAgent, body: dict, api_url: str, log_context: dict
):
apply_repo_settings(api_url)
if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}")
if (
commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
): # auto commands for PR, and auto feedback is disabled
get_logger().info(
f"Auto feedback is disabled, skipping auto commands for PR {api_url=}"
)
return
if not should_process_pr_logic(body): # Here we already updated the configuration with the repo settings
if not should_process_pr_logic(
body
): # Here we already updated the configuration with the repo settings
return {}
commands = get_settings().get(f"github_app.{commands_conf}")
if not commands:
@ -398,7 +528,9 @@ async def _perform_auto_commands_github(commands_conf: str, agent: PRAgent, body
args = split_command[1:]
other_args = update_settings_from_args(args)
new_command = ' '.join([command] + other_args)
get_logger().info(f"{commands_conf}. Performing auto command '{new_command}', for {api_url=}")
get_logger().info(
f"{commands_conf}. Performing auto command '{new_command}', for {api_url=}"
)
await agent.handle_request(api_url, new_command)

@ -18,11 +18,13 @@ NOTIFICATION_URL = "https://api.github.com/notifications"

async def mark_notification_as_read(headers, notification, session):
async with session.patch(
f"https://api.github.com/notifications/threads/{notification['id']}",
headers=headers) as mark_read_response:
f"https://api.github.com/notifications/threads/{notification['id']}",
headers=headers,
) as mark_read_response:
if mark_read_response.status != 205:
get_logger().error(
f"Failed to mark notification as read. Status code: {mark_read_response.status}")
f"Failed to mark notification as read. Status code: {mark_read_response.status}"
)

def now() -> str:
@ -36,17 +38,21 @@ def now() -> str:
now_utc = now_utc.replace("+00:00", "Z")
return now_utc

async def async_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
agent = PRAgent()
success = await agent.handle_request(
pr_url,
rest_of_comment,
notify=lambda: git_provider.add_eyes_reaction(comment_id)
notify=lambda: git_provider.add_eyes_reaction(comment_id),
)
return success

def run_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
return asyncio.run(async_handle_request(pr_url, rest_of_comment, comment_id, git_provider))
return asyncio.run(
async_handle_request(pr_url, rest_of_comment, comment_id, git_provider)
)

def process_comment_sync(pr_url, rest_of_comment, comment_id):
@ -55,7 +61,10 @@ def process_comment_sync(pr_url, rest_of_comment, comment_id):
git_provider = get_git_provider()(pr_url=pr_url)
success = run_handle_request(pr_url, rest_of_comment, comment_id, git_provider)
except Exception as e:
get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Error processing comment: {e}",
artifact={"traceback": traceback.format_exc()},
)

async def process_comment(pr_url, rest_of_comment, comment_id):
@ -66,22 +75,31 @@ async def process_comment(pr_url, rest_of_comment, comment_id):
success = await agent.handle_request(
pr_url,
rest_of_comment,
notify=lambda: git_provider.add_eyes_reaction(comment_id)
notify=lambda: git_provider.add_eyes_reaction(comment_id),
)
get_logger().info(f"Finished processing comment for PR: {pr_url}")
except Exception as e:
get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Error processing comment: {e}",
artifact={"traceback": traceback.format_exc()},
)

async def is_valid_notification(notification, headers, handled_ids, session, user_id):
try:
if 'reason' in notification and notification['reason'] == 'mention':
if 'subject' in notification and notification['subject']['type'] == 'PullRequest':
if (
'subject' in notification
and notification['subject']['type'] == 'PullRequest'
):
pr_url = notification['subject']['url']
latest_comment = notification['subject']['latest_comment_url']
if not latest_comment or not isinstance(latest_comment, str):
get_logger().debug(f"no latest_comment")
return False, handled_ids
async with session.get(latest_comment, headers=headers) as comment_response:
async with session.get(
latest_comment, headers=headers
) as comment_response:
check_prev_comments = False
user_tag = "@" + user_id
if comment_response.status == 200:
@ -94,7 +112,9 @@ async def is_valid_notification(notification, headers, handled_ids, session, use
handled_ids.add(comment['id'])
if 'user' in comment and 'login' in comment['user']:
if comment['user']['login'] == user_id:
get_logger().debug(f"comment['user']['login'] == user_id")
get_logger().debug(
f"comment['user']['login'] == user_id"
)
check_prev_comments = True
comment_body = comment.get('body', '')
if not comment_body:
@ -105,15 +125,28 @@ async def is_valid_notification(notification, headers, handled_ids, session, use
get_logger().debug(f"user_tag not in comment_body")
check_prev_comments = True
else:
get_logger().info(f"Polling, pr_url: {pr_url}",
artifact={"comment": comment_body})
get_logger().info(
f"Polling, pr_url: {pr_url}",
artifact={"comment": comment_body},
)

if not check_prev_comments:
return True, handled_ids, comment, comment_body, pr_url, user_tag
else: # we could not find the user tag in the latest comment. Check previous comments
return (
True,
handled_ids,
comment,
comment_body,
pr_url,
user_tag,
)
else: # we could not find the user tag in the latest comment. Check previous comments
# get all comments in the PR
requests_url = f"{pr_url}/comments".replace("pulls", "issues")
comments_response = requests.get(requests_url, headers=headers)
requests_url = f"{pr_url}/comments".replace(
"pulls", "issues"
)
comments_response = requests.get(
requests_url, headers=headers
)
comments = comments_response.json()[::-1]
max_comment_to_scan = 4
for comment in comments[:max_comment_to_scan]:
@ -124,23 +157,37 @@ async def is_valid_notification(notification, headers, handled_ids, session, use
if not comment_body:
continue
if user_tag in comment_body:
get_logger().info("found user tag in previous comments")
get_logger().info(f"Polling, pr_url: {pr_url}",
artifact={"comment": comment_body})
return True, handled_ids, comment, comment_body, pr_url, user_tag
get_logger().info(
"found user tag in previous comments"
)
get_logger().info(
f"Polling, pr_url: {pr_url}",
artifact={"comment": comment_body},
)
return (
True,
handled_ids,
comment,
comment_body,
pr_url,
user_tag,
)

get_logger().warning(f"Failed to fetch comments for PR: {pr_url}",
artifact={"comments": comments})
get_logger().warning(
f"Failed to fetch comments for PR: {pr_url}",
artifact={"comments": comments},
)
return False, handled_ids

return False, handled_ids
except Exception as e:
get_logger().exception(f"Error processing polling notification",
artifact={"notification": notification, "error": e})
get_logger().exception(
f"Error processing polling notification",
artifact={"notification": notification, "error": e},
)
return False, handled_ids

async def polling_loop():
"""
Polls for notifications and handles them accordingly.
@ -171,17 +218,17 @@ async def polling_loop():
await asyncio.sleep(5)
headers = {
"Accept": "application/vnd.github.v3+json",
"Authorization": f"Bearer {token}"
}
params = {
"participating": "true"
"Authorization": f"Bearer {token}",
}
params = {"participating": "true"}
if since[0]:
params["since"] = since[0]
if last_modified[0]:
headers["If-Modified-Since"] = last_modified[0]

async with session.get(NOTIFICATION_URL, headers=headers, params=params) as response:
async with session.get(
NOTIFICATION_URL, headers=headers, params=params
) as response:
if response.status == 200:
if 'Last-Modified' in response.headers:
last_modified[0] = response.headers['Last-Modified']
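
The `If-Modified-Since` handling above is the standard conditional-request trick: replay the last `Last-Modified` value so GitHub can answer 304 with no body when nothing has changed since the previous poll. A reduced sketch of one polling step, assuming `aiohttp` is installed and using a placeholder token:

import asyncio
import aiohttp

NOTIFICATION_URL = "https://api.github.com/notifications"

async def poll_once(session, headers, last_modified):
    if last_modified[0]:
        headers["If-Modified-Since"] = last_modified[0]
    async with session.get(NOTIFICATION_URL, headers=headers) as response:
        if response.status == 200:
            # remember the server's timestamp for the next conditional request
            last_modified[0] = response.headers.get("Last-Modified")
            return await response.json()
        return []  # 304 Not Modified: nothing new since the last poll

async def main():
    headers = {"Accept": "application/vnd.github.v3+json",
               "Authorization": "Bearer <TOKEN>"}  # placeholder token
    last_modified = [None]  # single-element list, mirroring the code above
    async with aiohttp.ClientSession() as session:
        print(await poll_once(session, headers, last_modified))

asyncio.run(main())
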
@ -189,39 +236,67 @@ async def polling_loop():
notifications = await response.json()
if not notifications:
continue
get_logger().info(f"Received {len(notifications)} notifications")
get_logger().info(
f"Received {len(notifications)} notifications"
)
task_queue = deque()
for notification in notifications:
if not notification:
continue
# mark notification as read
await mark_notification_as_read(headers, notification, session)
await mark_notification_as_read(
headers, notification, session
)

handled_ids.add(notification['id'])
output = await is_valid_notification(notification, headers, handled_ids, session, user_id)
output = await is_valid_notification(
notification, headers, handled_ids, session, user_id
)
if output[0]:
_, handled_ids, comment, comment_body, pr_url, user_tag = output
rest_of_comment = comment_body.split(user_tag)[1].strip()
(
_,
handled_ids,
comment,
comment_body,
pr_url,
user_tag,
) = output
rest_of_comment = comment_body.split(user_tag)[
1
].strip()
comment_id = comment['id']

# Add to the task queue
get_logger().info(
f"Adding comment processing to task queue for PR, {pr_url}, comment_body: {comment_body}")
task_queue.append((process_comment_sync, (pr_url, rest_of_comment, comment_id)))
get_logger().info(f"Queued comment processing for PR: {pr_url}")
f"Adding comment processing to task queue for PR, {pr_url}, comment_body: {comment_body}"
)
task_queue.append(
(
process_comment_sync,
(pr_url, rest_of_comment, comment_id),
)
)
get_logger().info(
f"Queued comment processing for PR: {pr_url}"
)
else:
get_logger().debug(f"Skipping comment processing for PR")
get_logger().debug(
f"Skipping comment processing for PR"
)

max_allowed_parallel_tasks = 10
if task_queue:
processes = []
for i, (func, args) in enumerate(task_queue): # Create parallel tasks
for i, (func, args) in enumerate(
task_queue
): # Create parallel tasks
p = multiprocessing.Process(target=func, args=args)
processes.append(p)
p.start()
if i > max_allowed_parallel_tasks:
get_logger().error(
f"Dropping {len(task_queue) - max_allowed_parallel_tasks} tasks from polling session")
f"Dropping {len(task_queue) - max_allowed_parallel_tasks} tasks from polling session"
)
break
task_queue.clear()
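
The loop above fans queued comments out to worker processes but refuses to launch more than a fixed number, dropping the remainder. A simplified standalone sketch of the same cap; here the bound is checked before `start()`, unlike the post-start check above, so exactly ten workers run:

import multiprocessing

def worker(task_id):
    print(f"processing task {task_id}")

if __name__ == "__main__":
    task_queue = [(worker, (i,)) for i in range(25)]  # toy backlog
    max_allowed_parallel_tasks = 10
    processes = []
    for i, (func, args) in enumerate(task_queue):
        if i >= max_allowed_parallel_tasks:
            print(f"Dropping {len(task_queue) - i} tasks from this session")
            break
        p = multiprocessing.Process(target=func, args=args)
        processes.append(p)
        p.start()
    for p in processes:
        p.join()  # the original leaves joining commented out; shown here for a clean exit
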
@ -230,11 +305,15 @@ async def polling_loop():
# p.join()

elif response.status != 304:
print(f"Failed to fetch notifications. Status code: {response.status}")
print(
f"Failed to fetch notifications. Status code: {response.status}"
)

except Exception as e:
get_logger().error(f"Polling exception during processing of a notification: {e}",
artifact={"traceback": traceback.format_exc()})
get_logger().error(
f"Polling exception during processing of a notification: {e}",
artifact={"traceback": traceback.format_exc()},
)

if __name__ == '__main__':

@ -22,20 +22,21 @@ from utils.pr_agent.secret_providers import get_secret_provider
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
router = APIRouter()

secret_provider = get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
secret_provider = (
    get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
)

async def get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id):
try:
import requests
headers = {
'Private-Token': f'{gitlab_token}'
}

headers = {'Private-Token': f'{gitlab_token}'}
# API endpoint to find MRs containing the commit
gitlab_url = get_settings().get("GITLAB.URL", 'https://gitlab.com')
response = requests.get(
f'{gitlab_url}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/merge_requests',
headers=headers
headers=headers,
)
merge_requests = response.json()
if merge_requests and response.status_code == 200:
@ -48,6 +49,7 @@ async def get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id):
get_logger().error(f"Failed to get MR url from commit sha: {e}")
return None
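
The endpoint used above is GitLab's commit-to-MR lookup: `GET /projects/:id/repository/commits/:sha/merge_requests` returns every merge request that contains the given commit. A standalone sketch with placeholder ids and token:

import requests

def mr_url_for_commit(gitlab_url, project_id, commit_sha, token):
    response = requests.get(
        f"{gitlab_url}/api/v4/projects/{project_id}/repository/commits/"
        f"{commit_sha}/merge_requests",
        headers={"Private-Token": token},
    )
    merge_requests = response.json()
    if response.status_code == 200 and merge_requests:
        # take the first entry, as the handler above does; the API's
        # ordering guarantee is not something this sketch relies on
        return merge_requests[0].get("web_url")
    return None

print(mr_url_for_commit("https://gitlab.com", 123, "abc123def", "<TOKEN>"))  # placeholders
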
async def handle_request(api_url: str, body: str, log_context: dict, sender_id: str):
log_context["action"] = body
log_context["event"] = "pull_request" if body == "/review" else "comment"
@ -58,13 +60,19 @@ async def handle_request(api_url: str, body: str, log_context: dict, sender_id:
await PRAgent().handle_request(api_url, body)

async def _perform_commands_gitlab(commands_conf: str, agent: PRAgent, api_url: str,
log_context: dict, data: dict):
async def _perform_commands_gitlab(
commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict
):
apply_repo_settings(api_url)
if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}", **log_context)
if (
commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
): # auto commands for PR, and auto feedback is disabled
get_logger().info(
f"Auto feedback is disabled, skipping auto commands for PR {api_url=}",
**log_context,
)
return
if not should_process_pr_logic(data): # Here we already updated the configurations
if not should_process_pr_logic(data): # Here we already updated the configurations
return
commands = get_settings().get(f"gitlab.{commands_conf}", {})
get_settings().set("config.is_auto_command", True)
@ -106,40 +114,58 @@ def should_process_pr_logic(data) -> bool:
ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
if ignore_pr_users and sender:
if sender in ignore_pr_users:
get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' settings")
get_logger().info(
f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' settings"
)
return False

# logic to ignore MRs for titles, labels and source, target branches.
ignore_mr_title = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
ignore_mr_labels = get_settings().get("CONFIG.IGNORE_PR_LABELS", [])
ignore_mr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [])
ignore_mr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", [])
ignore_mr_source_branches = get_settings().get(
"CONFIG.IGNORE_PR_SOURCE_BRANCHES", []
)
ignore_mr_target_branches = get_settings().get(
"CONFIG.IGNORE_PR_TARGET_BRANCHES", []
)

#
if ignore_mr_source_branches:
source_branch = data['object_attributes'].get('source_branch')
if any(re.search(regex, source_branch) for regex in ignore_mr_source_branches):
if any(
re.search(regex, source_branch) for regex in ignore_mr_source_branches
):
get_logger().info(
f"Ignoring MR with source branch '{source_branch}' due to gitlab.ignore_mr_source_branches settings")
f"Ignoring MR with source branch '{source_branch}' due to gitlab.ignore_mr_source_branches settings"
)
return False

if ignore_mr_target_branches:
target_branch = data['object_attributes'].get('target_branch')
if any(re.search(regex, target_branch) for regex in ignore_mr_target_branches):
if any(
re.search(regex, target_branch) for regex in ignore_mr_target_branches
):
get_logger().info(
f"Ignoring MR with target branch '{target_branch}' due to gitlab.ignore_mr_target_branches settings")
f"Ignoring MR with target branch '{target_branch}' due to gitlab.ignore_mr_target_branches settings"
)
return False

if ignore_mr_labels:
labels = [label['title'] for label in data['object_attributes'].get('labels', [])]
labels = [
label['title'] for label in data['object_attributes'].get('labels', [])
]
if any(label in ignore_mr_labels for label in labels):
labels_str = ", ".join(labels)
get_logger().info(f"Ignoring MR with labels '{labels_str}' due to gitlab.ignore_mr_labels settings")
get_logger().info(
f"Ignoring MR with labels '{labels_str}' due to gitlab.ignore_mr_labels settings"
)
return False

if ignore_mr_title:
if any(re.search(regex, title) for regex in ignore_mr_title):
get_logger().info(f"Ignoring MR with title '{title}' due to gitlab.ignore_mr_title settings")
get_logger().info(
f"Ignoring MR with title '{title}' due to gitlab.ignore_mr_title settings"
)
return False
except Exception as e:
get_logger().error(f"Failed 'should_process_pr_logic': {e}")
@ -159,29 +185,47 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
request_token = request.headers.get("X-Gitlab-Token")
secret = secret_provider.get_secret(request_token)
if not secret:
get_logger().warning(f"Empty secret retrieved, request_token: {request_token}")
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED,
content=jsonable_encoder({"message": "unauthorized"}))
get_logger().warning(
f"Empty secret retrieved, request_token: {request_token}"
)
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=jsonable_encoder({"message": "unauthorized"}),
)
try:
secret_dict = json.loads(secret)
gitlab_token = secret_dict["gitlab_token"]
log_context["token_id"] = secret_dict.get("token_name", secret_dict.get("id", "unknown"))
log_context["token_id"] = secret_dict.get(
"token_name", secret_dict.get("id", "unknown")
)
context["settings"].gitlab.personal_access_token = gitlab_token
except Exception as e:
get_logger().error(f"Failed to validate secret {request_token}: {e}")
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=jsonable_encoder({"message": "unauthorized"}),
)
elif get_settings().get("GITLAB.SHARED_SECRET"):
secret = get_settings().get("GITLAB.SHARED_SECRET")
if not request.headers.get("X-Gitlab-Token") == secret:
get_logger().error("Failed to validate secret")
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=jsonable_encoder({"message": "unauthorized"}),
)
else:
get_logger().error("Failed to validate secret")
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=jsonable_encoder({"message": "unauthorized"}),
)
gitlab_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
if not gitlab_token:
get_logger().error("No gitlab token found")
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content=jsonable_encoder({"message": "unauthorized"}),
)

get_logger().info("GitLab data", artifact=data)
sender = data.get("user", {}).get("username", "unknown")
@ -189,31 +233,49 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):

# ignore bot users
if is_bot_user(data):
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
if data.get('event_type') != 'note': # not a comment
return JSONResponse(
status_code=status.HTTP_200_OK,
content=jsonable_encoder({"message": "success"}),
)
if data.get('event_type') != 'note': # not a comment
# ignore MRs based on title, labels, source and target branches
if not should_process_pr_logic(data):
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
return JSONResponse(
status_code=status.HTTP_200_OK,
content=jsonable_encoder({"message": "success"}),
)

log_context["sender"] = sender
if data.get('object_kind') == 'merge_request' and data['object_attributes'].get('action') in ['open', 'reopen']:
if data.get('object_kind') == 'merge_request' and data['object_attributes'].get(
'action'
) in ['open', 'reopen']:
title = data['object_attributes'].get('title')
url = data['object_attributes'].get('url')
draft = data['object_attributes'].get('draft')
get_logger().info(f"New merge request: {url}")
if draft:
get_logger().info(f"Skipping draft MR: {url}")
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
return JSONResponse(
status_code=status.HTTP_200_OK,
content=jsonable_encoder({"message": "success"}),
)

await _perform_commands_gitlab("pr_commands", PRAgent(), url, log_context, data)
elif data.get('object_kind') == 'note' and data.get('event_type') == 'note': # comment on MR
await _perform_commands_gitlab(
"pr_commands", PRAgent(), url, log_context, data
)
elif (
data.get('object_kind') == 'note' and data.get('event_type') == 'note'
): # comment on MR
if 'merge_request' in data:
mr = data['merge_request']
url = mr.get('url')

get_logger().info(f"A comment has been added to a merge request: {url}")
body = data.get('object_attributes', {}).get('note')
if data.get('object_attributes', {}).get('type') == 'DiffNote' and '/ask' in body: # /ask_line
if (
data.get('object_attributes', {}).get('type') == 'DiffNote'
and '/ask' in body
): # /ask_line
body = handle_ask_line(body, data)

await handle_request(url, body, log_context, sender_id)
@ -221,30 +283,44 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
try:
project_id = data['project_id']
commit_sha = data['checkout_sha']
url = await get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id)
url = await get_mr_url_from_commit_sha(
commit_sha, gitlab_token, project_id
)
if not url:
get_logger().info(f"No MR found for commit: {commit_sha}")
return JSONResponse(status_code=status.HTTP_200_OK,
content=jsonable_encoder({"message": "success"}))
return JSONResponse(
status_code=status.HTTP_200_OK,
content=jsonable_encoder({"message": "success"}),
)

# we need first to apply_repo_settings
apply_repo_settings(url)
commands_on_push = get_settings().get(f"gitlab.push_commands", {})
handle_push_trigger = get_settings().get(f"gitlab.handle_push_trigger", False)
handle_push_trigger = get_settings().get(
f"gitlab.handle_push_trigger", False
)
if not commands_on_push or not handle_push_trigger:
get_logger().info("Push event, but no push commands found or push trigger is disabled")
return JSONResponse(status_code=status.HTTP_200_OK,
content=jsonable_encoder({"message": "success"}))
get_logger().info(
"Push event, but no push commands found or push trigger is disabled"
)
return JSONResponse(
status_code=status.HTTP_200_OK,
content=jsonable_encoder({"message": "success"}),
)

get_logger().debug(f'A push event has been received: {url}')
await _perform_commands_gitlab("push_commands", PRAgent(), url, log_context, data)
await _perform_commands_gitlab(
"push_commands", PRAgent(), url, log_context, data
)
except Exception as e:
get_logger().error(f"Failed to handle push event: {e}")

background_tasks.add_task(inner, request_json)
end_time = datetime.now()
get_logger().info(f"Processing time: {end_time - start_time}", request=request_json)
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
return JSONResponse(
status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"})
)

def handle_ask_line(body, data):
@ -271,6 +347,7 @@ def handle_ask_line(body, data):
async def root():
return {"status": "ok"}

gitlab_url = get_settings().get("GITLAB.URL", None)
if not gitlab_url:
raise ValueError("GITLAB.URL is not set")

@ -1,18 +1,19 @@
class HelpMessage:
@staticmethod
def get_general_commands_text():
commands_text = "> - **/review**: Request a review of your Pull Request. \n" \
"> - **/describe**: Update the PR title and description based on the contents of the PR. \n" \
"> - **/improve [--extended]**: Suggest code improvements. Extended mode provides a higher quality feedback. \n" \
"> - **/ask \\<QUESTION\\>**: Ask a question about the PR. \n" \
"> - **/update_changelog**: Update the changelog based on the PR's contents. \n" \
"> - **/add_docs** 💎: Generate docstring for new components introduced in the PR. \n" \
"> - **/generate_labels** 💎: Generate labels for the PR based on the PR's contents. \n" \
"> - **/analyze** 💎: Automatically analyzes the PR, and presents changes walkthrough for each component. \n\n" \
">See the [tools guide](https://pr-agent-docs.codium.ai/tools/) for more details.\n" \
">To list the possible configuration parameters, add a **/config** comment. \n"
return commands_text

commands_text = (
"> - **/review**: Request a review of your Pull Request. \n"
"> - **/describe**: Update the PR title and description based on the contents of the PR. \n"
"> - **/improve [--extended]**: Suggest code improvements. Extended mode provides a higher quality feedback. \n"
"> - **/ask \\<QUESTION\\>**: Ask a question about the PR. \n"
"> - **/update_changelog**: Update the changelog based on the PR's contents. \n"
"> - **/add_docs** 💎: Generate docstring for new components introduced in the PR. \n"
"> - **/generate_labels** 💎: Generate labels for the PR based on the PR's contents. \n"
"> - **/analyze** 💎: Automatically analyzes the PR, and presents changes walkthrough for each component. \n\n"
">See the [tools guide](https://pr-agent-docs.codium.ai/tools/) for more details.\n"
">To list the possible configuration parameters, add a **/config** comment. \n"
)
return commands_text

@staticmethod
def get_general_bot_help_text():
@ -21,10 +22,12 @@ class HelpMessage:

@staticmethod
def get_review_usage_guide():
output ="**Overview:**\n"
output +=("The `review` tool scans the PR code changes, and generates a PR review which includes several types of feedbacks, such as possible PR issues, security threats and relevant test in the PR. More feedbacks can be [added](https://pr-agent-docs.codium.ai/tools/review/#general-configurations) by configuring the tool.\n\n"
"The tool can be triggered [automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) every time a new PR is opened, or can be invoked manually by commenting on any PR.\n")
output +="""\
output = "**Overview:**\n"
output += (
"The `review` tool scans the PR code changes, and generates a PR review which includes several types of feedbacks, such as possible PR issues, security threats and relevant test in the PR. More feedbacks can be [added](https://pr-agent-docs.codium.ai/tools/review/#general-configurations) by configuring the tool.\n\n"
"The tool can be triggered [automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) every time a new PR is opened, or can be invoked manually by commenting on any PR.\n"
)
output += """\
- When commenting, to edit [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L23) related to the review tool (`pr_reviewer` section), use the following template:
```
/review --pr_reviewer.some_config1=... --pr_reviewer.some_config2=...
@ -41,8 +44,6 @@ some_config2=...

return output

@staticmethod
def get_describe_usage_guide():
output = "**Overview:**\n"
@ -137,7 +138,6 @@ Use triple quotes to write multi-line instructions. Use bullet points to make th
'''
output += "\n\n</details></td></tr>\n\n"

# general
output += "\n\n<tr><td><details> <summary><strong> More PR-Agent commands</strong></summary><hr> \n\n"
output += HelpMessage.get_general_bot_help_text()
@ -175,7 +175,6 @@ You can ask questions about the entire PR, about specific code lines, or about a

return output

@staticmethod
def get_improve_usage_guide():
output = "**Overview:**\n"

@ -18,8 +18,12 @@ def verify_signature(payload_body, secret_token, signature_header):
signature_header: header received from GitHub (x-hub-signature-256)
"""
if not signature_header:
raise HTTPException(status_code=403, detail="x-hub-signature-256 header is missing!")
hash_object = hmac.new(secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256)
raise HTTPException(
status_code=403, detail="x-hub-signature-256 header is missing!"
)
hash_object = hmac.new(
secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256
)
expected_signature = "sha256=" + hash_object.hexdigest()
if not hmac.compare_digest(expected_signature, signature_header):
raise HTTPException(status_code=403, detail="Request signatures didn't match!")
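
For context on `verify_signature` above: GitHub signs the raw request body with an HMAC-SHA256 keyed by the shared webhook secret and sends `sha256=<hexdigest>` in `x-hub-signature-256`; the server recomputes the digest over the exact same bytes and compares with `hmac.compare_digest` to avoid timing leaks. A self-contained round trip using an example secret:

import hashlib
import hmac

secret_token = "webhook-secret"          # example secret
payload_body = b'{"action": "opened"}'   # raw request bytes, not re-serialized JSON

# What GitHub would compute and send in x-hub-signature-256:
signature_header = "sha256=" + hmac.new(
    secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256
).hexdigest()

# What the server recomputes, exactly as in verify_signature above:
expected_signature = "sha256=" + hmac.new(
    secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256
).hexdigest()

assert hmac.compare_digest(expected_signature, signature_header)
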
@ -27,6 +31,7 @@ def verify_signature(payload_body, secret_token, signature_header):

class RateLimitExceeded(Exception):
"""Raised when the git provider API rate limit has been exceeded."""

pass

@ -66,7 +71,11 @@ class DefaultDictWithTimeout(defaultdict):
request_time = self.__time()
if request_time - self.__last_refresh > self.__refresh_interval:
return
to_delete = [key for key, key_time in self.__key_times.items() if request_time - key_time > self.__ttl]
to_delete = [
key
for key, key_time in self.__key_times.items()
if request_time - key_time > self.__ttl
]
for key in to_delete:
del self[key]
self.__last_refresh = request_time
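
Usage sketch for the purge logic above: each key carries a creation timestamp, and any key older than `ttl` seconds is dropped, at most once per refresh interval. The constructor arguments below (default factory, `ttl`, `refresh_interval`) are inferred from the calls earlier in this diff and from the fields the purge reads; treat the exact signature and import path as assumptions.

import time
# assumed import path for the class shown above:
# from utils.pr_agent.servers.utils import DefaultDictWithTimeout

counters = DefaultDictWithTimeout(int, ttl=2, refresh_interval=1)  # arg names assumed
counters["pr-123"] += 1          # key recorded with its creation time
time.sleep(3)                    # let the 2-second TTL lapse
counters["pr-456"] += 1          # a later access runs the purge shown above
print("pr-123" in counters)      # expected False once the stale key is deleted
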
@ -17,9 +17,13 @@ from utils.pr_agent.log import get_logger

class PRAddDocs:
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):

def __init__(
self,
pr_url: str,
cli_mode=False,
args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
):
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
@ -39,13 +43,16 @@ class PRAddDocs:
"diff": "", # empty diff for initial calculation
"extra_instructions": get_settings().pr_add_docs.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
'docs_for_language': get_docs_for_language(self.main_language,
get_settings().pr_add_docs.docs_style),
'docs_for_language': get_docs_for_language(
self.main_language, get_settings().pr_add_docs.docs_style
),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_add_docs_prompt.system,
get_settings().pr_add_docs_prompt.user)
self.token_handler = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_add_docs_prompt.system,
get_settings().pr_add_docs_prompt.user,
)

async def run(self):
try:
@ -66,16 +73,20 @@ class PRAddDocs:
get_logger().info('Pushing inline code documentation...')
self.push_inline_docs(data)
except Exception as e:
get_logger().error(f"Failed to generate code documentation for PR, error: {e}")
get_logger().error(
f"Failed to generate code documentation for PR, error: {e}"
)

async def _prepare_prediction(self, model: str):
get_logger().info('Getting PR diff...')

self.patches_diff = get_pr_diff(self.git_provider,
self.token_handler,
model,
add_line_numbers_to_hunks=True,
disable_extra_lines=False)
self.patches_diff = get_pr_diff(
self.git_provider,
self.token_handler,
model,
add_line_numbers_to_hunks=True,
disable_extra_lines=False,
)

get_logger().info('Getting AI prediction...')
self.prediction = await self._get_prediction(model)
@ -84,13 +95,21 @@ class PRAddDocs:
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_add_docs_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_add_docs_prompt.user).render(variables)
system_prompt = environment.from_string(
get_settings().pr_add_docs_prompt.system
).render(variables)
user_prompt = environment.from_string(
get_settings().pr_add_docs_prompt.user
).render(variables)
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"\nSystem prompt:\n{system_prompt}")
get_logger().info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(
model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
user=user_prompt,
)

return response
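
The rendering step above relies on `StrictUndefined`: with it, a template variable missing from `variables` raises `UndefinedError` instead of silently rendering as an empty string, so a broken prompt fails loudly before ever reaching the model. A minimal demonstration:

from jinja2 import Environment, StrictUndefined

environment = Environment(undefined=StrictUndefined)
template = environment.from_string("Review {{ title }}:\n{{ diff }}")

print(template.render(title="my PR", diff="...patch..."))  # renders normally
try:
    template.render(title="my PR")  # 'diff' left out
except Exception as e:
    print(f"refused to render: {e}")  # jinja2.exceptions.UndefinedError
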
@ -105,7 +124,9 @@ class PRAddDocs:
docs = []

if not data['Code Documentation']:
return self.git_provider.publish_comment('No code documentation found to improve this PR.')
return self.git_provider.publish_comment(
'No code documentation found to improve this PR.'
)

for d in data['Code Documentation']:
try:
@ -116,32 +137,59 @@ class PRAddDocs:
documentation = d['documentation']
doc_placement = d['doc placement'].strip()
if documentation:
new_code_snippet = self.dedent_code(relevant_file, relevant_line, documentation, doc_placement,
add_original_line=True)
new_code_snippet = self.dedent_code(
relevant_file,
relevant_line,
documentation,
doc_placement,
add_original_line=True,
)

body = f"**Suggestion:** Proposed documentation\n```suggestion\n" + new_code_snippet + "\n```"
docs.append({'body': body, 'relevant_file': relevant_file,
'relevant_lines_start': relevant_line,
'relevant_lines_end': relevant_line})
body = (
f"**Suggestion:** Proposed documentation\n```suggestion\n"
+ new_code_snippet
+ "\n```"
)
docs.append(
{
'body': body,
'relevant_file': relevant_file,
'relevant_lines_start': relevant_line,
'relevant_lines_end': relevant_line,
}
)
except Exception:
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Could not parse code docs: {d}")

is_successful = self.git_provider.publish_code_suggestions(docs)
if not is_successful:
get_logger().info("Failed to publish code docs, trying to publish each docs separately")
get_logger().info(
"Failed to publish code docs, trying to publish each docs separately"
)
for doc_suggestion in docs:
self.git_provider.publish_code_suggestions([doc_suggestion])

def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet, doc_placement='after',
add_original_line=False):
def dedent_code(
self,
relevant_file,
relevant_lines_start,
new_code_snippet,
doc_placement='after',
add_original_line=False,
):
try: # dedent code snippet
self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files \
self.diff_files = (
self.git_provider.diff_files
if self.git_provider.diff_files
else self.git_provider.get_diff_files()
)
original_initial_line = None
for file in self.diff_files:
if file.filename.strip() == relevant_file:
original_initial_line = file.head_file.splitlines()[relevant_lines_start - 1]
original_initial_line = file.head_file.splitlines()[
relevant_lines_start - 1
]
break
if original_initial_line:
if doc_placement == 'after':
@ -150,18 +198,28 @@ class PRAddDocs:
line = original_initial_line
suggested_initial_line = new_code_snippet.splitlines()[0]
original_initial_spaces = len(line) - len(line.lstrip())
suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip())
suggested_initial_spaces = len(suggested_initial_line) - len(
suggested_initial_line.lstrip()
)
delta_spaces = original_initial_spaces - suggested_initial_spaces
if delta_spaces > 0:
new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
new_code_snippet = textwrap.indent(
new_code_snippet, delta_spaces * " "
).rstrip('\n')
if add_original_line:
if doc_placement == 'after':
new_code_snippet = original_initial_line + "\n" + new_code_snippet
new_code_snippet = (
original_initial_line + "\n" + new_code_snippet
)
else:
new_code_snippet = new_code_snippet.rstrip() + "\n" + original_initial_line
new_code_snippet = (
new_code_snippet.rstrip() + "\n" + original_initial_line
)
except Exception as e:
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")
get_logger().info(
f"Could not dedent code snippet for file {relevant_file}, error: {e}"
)

return new_code_snippet
File diff suppressed because it is too large
@@ -9,6 +9,7 @@ class PRConfig:
    """
    The PRConfig class is responsible for listing all configuration options available for the user.
    """

    def __init__(self, pr_url: str, args=None, ai_handler=None):
        """
        Initialize the PRConfig object with the necessary attributes and objects to comment on a pull request.
@@ -34,20 +35,43 @@ class PRConfig:
        conf_settings = Dynaconf(settings_files=[conf_file])
        configuration_headers = [header.lower() for header in conf_settings.keys()]
        relevant_configs = {
            header: configs for header, configs in get_settings().to_dict().items()
            if (header.lower().startswith("pr_") or header.lower().startswith("config")) and header.lower() in configuration_headers
            header: configs
            for header, configs in get_settings().to_dict().items()
            if (header.lower().startswith("pr_") or header.lower().startswith("config"))
            and header.lower() in configuration_headers
        }

        skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys", "app_id", "redirect",
                     'trial_prefix_message', 'no_eligible_message', 'identity_provider', 'ALLOWED_REPOS',
                     'APP_NAME', 'PERSONAL_ACCESS_TOKEN', 'shared_secret', 'key', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'user_token',
                     'private_key', 'private_key_id', 'client_id', 'client_secret', 'token', 'bearer_token']
        skip_keys = [
            'ai_disclaimer',
            'ai_disclaimer_title',
            'ANALYTICS_FOLDER',
            'secret_provider',
            "skip_keys",
            "app_id",
            "redirect",
            'trial_prefix_message',
            'no_eligible_message',
            'identity_provider',
            'ALLOWED_REPOS',
            'APP_NAME',
            'PERSONAL_ACCESS_TOKEN',
            'shared_secret',
            'key',
            'AWS_ACCESS_KEY_ID',
            'AWS_SECRET_ACCESS_KEY',
            'user_token',
            'private_key',
            'private_key_id',
            'client_id',
            'client_secret',
            'token',
            'bearer_token',
        ]
        extra_skip_keys = get_settings().config.get('config.skip_keys', [])
        if extra_skip_keys:
            skip_keys.extend(extra_skip_keys)
        skip_keys_lower = [key.lower() for key in skip_keys]


        markdown_text = "<details> <summary><strong>🛠️ PR-Agent Configurations:</strong></summary> \n\n"
        markdown_text += f"\n\n```yaml\n\n"
        for header, configs in relevant_configs.items():
@@ -61,5 +85,7 @@ class PRConfig:
            markdown_text += " "
        markdown_text += "\n```"
        markdown_text += "\n</details>\n"
        get_logger().info(f"Possible Configurations outputted to PR comment", artifact=markdown_text)
        get_logger().info(
            f"Possible Configurations outputted to PR comment", artifact=markdown_text
        )
        return markdown_text

@@ -10,27 +10,38 @@ from jinja2 import Environment, StrictUndefined

from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from utils.pr_agent.algo.pr_processing import (OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD,
                                               get_pr_diff,
                                               get_pr_diff_multiple_patchs,
                                               retry_with_fallback_models)
from utils.pr_agent.algo.pr_processing import (
    OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD,
    get_pr_diff,
    get_pr_diff_multiple_patchs,
    retry_with_fallback_models,
)
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.utils import (ModelType, PRDescriptionHeader, clip_tokens,
                                       get_max_tokens, get_user_labels, load_yaml,
                                       set_custom_labels,
                                       show_relevant_configurations)
from utils.pr_agent.algo.utils import (
    ModelType,
    PRDescriptionHeader,
    clip_tokens,
    get_max_tokens,
    get_user_labels,
    load_yaml,
    set_custom_labels,
    show_relevant_configurations,
)
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import (GithubProvider, get_git_provider_with_context)
from utils.pr_agent.git_providers import GithubProvider, get_git_provider_with_context
from utils.pr_agent.git_providers.git_provider import get_main_pr_language
from utils.pr_agent.log import get_logger
from utils.pr_agent.servers.help import HelpMessage
from utils.pr_agent.tools.ticket_pr_compliance_check import (
    extract_and_cache_pr_tickets)
from utils.pr_agent.tools.ticket_pr_compliance_check import extract_and_cache_pr_tickets


class PRDescription:
    def __init__(self, pr_url: str, args: list = None,
                 ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
    def __init__(
        self,
        pr_url: str,
        args: list = None,
        ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
    ):
        """
        Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
        using an AI model.
@@ -44,11 +55,22 @@ class PRDescription:
            self.git_provider.get_languages(), self.git_provider.get_files()
        )
        self.pr_id = self.git_provider.get_pr_id()
        self.keys_fix = ["filename:", "language:", "changes_summary:", "changes_title:", "description:", "title:"]
        self.keys_fix = [
            "filename:",
            "language:",
            "changes_summary:",
            "changes_title:",
            "description:",
            "title:",
        ]

        if get_settings().pr_description.enable_semantic_files_types and not self.git_provider.is_supported(
                "gfm_markdown"):
            get_logger().debug(f"Disabling semantic files types for {self.pr_id}, gfm_markdown not supported.")
        if (
            get_settings().pr_description.enable_semantic_files_types
            and not self.git_provider.is_supported("gfm_markdown")
        ):
            get_logger().debug(
                f"Disabling semantic files types for {self.pr_id}, gfm_markdown not supported."
            )
            get_settings().pr_description.enable_semantic_files_types = False

        # Initialize the AI handler
@@ -56,7 +78,9 @@ class PRDescription:
        self.ai_handler.main_pr_language = self.main_pr_language

        # Initialize the variables dictionary
        self.COLLAPSIBLE_FILE_LIST_THRESHOLD = get_settings().pr_description.get("collapsible_file_list_threshold", 8)
        self.COLLAPSIBLE_FILE_LIST_THRESHOLD = get_settings().pr_description.get(
            "collapsible_file_list_threshold", 8
        )
        self.vars = {
            "title": self.git_provider.pr.title,
            "branch": self.git_provider.get_pr_branch(),
@@ -69,8 +93,11 @@ class PRDescription:
            "custom_labels_class": "",  # will be filled if necessary in 'set_custom_labels' function
            "enable_semantic_files_types": get_settings().pr_description.enable_semantic_files_types,
            "related_tickets": "",
            "include_file_summary_changes": len(self.git_provider.get_diff_files()) <= self.COLLAPSIBLE_FILE_LIST_THRESHOLD,
            'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False),
            "include_file_summary_changes": len(self.git_provider.get_diff_files())
            <= self.COLLAPSIBLE_FILE_LIST_THRESHOLD,
            'duplicate_prompt_examples': get_settings().config.get(
                'duplicate_prompt_examples', False
            ),
        }

        self.user_description = self.git_provider.get_user_description()
@@ -91,10 +118,14 @@ class PRDescription:
    async def run(self):
        try:
            get_logger().info(f"Generating a PR description for pr_id: {self.pr_id}")
            relevant_configs = {'pr_description': dict(get_settings().pr_description),
                                'config': dict(get_settings().config)}
            relevant_configs = {
                'pr_description': dict(get_settings().pr_description),
                'config': dict(get_settings().config),
            }
            get_logger().debug("Relevant configs", artifacts=relevant_configs)
            if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False):
            if get_settings().config.publish_output and not get_settings().config.get(
                'is_auto_command', False
            ):
                self.git_provider.publish_comment("准备 PR 描述中...", is_temporary=True)

            # ticket extraction if exists
@@ -119,40 +150,73 @@ class PRDescription:
                get_logger().debug(f"Publishing labels disabled")

            if get_settings().pr_description.use_description_markers:
                pr_title, pr_body, changes_walkthrough, pr_file_changes = self._prepare_pr_answer_with_markers()
                (
                    pr_title,
                    pr_body,
                    changes_walkthrough,
                    pr_file_changes,
                ) = self._prepare_pr_answer_with_markers()
            else:
                pr_title, pr_body, changes_walkthrough, pr_file_changes = self._prepare_pr_answer()
                if not self.git_provider.is_supported(
                        "publish_file_comments") or not get_settings().pr_description.inline_file_summary:
                (
                    pr_title,
                    pr_body,
                    changes_walkthrough,
                    pr_file_changes,
                ) = self._prepare_pr_answer()
                if (
                    not self.git_provider.is_supported("publish_file_comments")
                    or not get_settings().pr_description.inline_file_summary
                ):
                    pr_body += "\n\n" + changes_walkthrough
            get_logger().debug("PR output", artifact={"title": pr_title, "body": pr_body})
            get_logger().debug(
                "PR output", artifact={"title": pr_title, "body": pr_body}
            )

            # Add help text if gfm_markdown is supported
            if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_description.enable_help_text:
            if (
                self.git_provider.is_supported("gfm_markdown")
                and get_settings().pr_description.enable_help_text
            ):
                pr_body += "<hr>\n\n<details> <summary><strong>✨ 工具使用指南:</strong></summary><hr> \n\n"
                pr_body += HelpMessage.get_describe_usage_guide()
                pr_body += "\n</details>\n"
            elif get_settings().pr_description.enable_help_comment and self.git_provider.is_supported("gfm_markdown"):
            elif (
                get_settings().pr_description.enable_help_comment
                and self.git_provider.is_supported("gfm_markdown")
            ):
                if isinstance(self.git_provider, GithubProvider):
                    pr_body += ('\n\n___\n\n> <details> <summary> 需要帮助?</summary><li>Type <code>/help 如何 ...</code> '
                                '关于PR-Agent使用的任何问题,请在评论区留言.</li><li>查看一下 '
                                '<a href="https://qodo-merge-docs.qodo.ai/usage-guide/">documentation</a> '
                                '了解更多.</li></details>')
                else:  # gitlab
                    pr_body += ("\n\n___\n\n<details><summary>需要帮助?</summary>- Type <code>/help 如何 ...</code> 在评论中 "
                                "关于PR-Agent使用的任何问题请在此发帖. <br>- 查看一下 "
                                "<a href='https://qodo-merge-docs.qodo.ai/usage-guide/'>documentation</a> 了解更多.</details>")
                    pr_body += (
                        '\n\n___\n\n> <details> <summary> 需要帮助?</summary><li>Type <code>/help 如何 ...</code> '
                        '关于PR-Agent使用的任何问题,请在评论区留言.</li><li>查看一下 '
                        '<a href="https://qodo-merge-docs.qodo.ai/usage-guide/">documentation</a> '
                        '了解更多.</li></details>'
                    )
                else:  # gitlab
                    pr_body += (
                        "\n\n___\n\n<details><summary>需要帮助?</summary>- Type <code>/help 如何 ...</code> 在评论中 "
                        "关于PR-Agent使用的任何问题请在此发帖. <br>- 查看一下 "
                        "<a href='https://qodo-merge-docs.qodo.ai/usage-guide/'>documentation</a> 了解更多.</details>"
                    )
            # elif get_settings().pr_description.enable_help_comment:
            #     pr_body += '\n\n___\n\n> 💡 **PR-Agent usage**: Comment `/help "your question"` on any pull request to receive relevant information'

            # Output the relevant configurations if enabled
            if get_settings().get('config', {}).get('output_relevant_configurations', False):
                pr_body += show_relevant_configurations(relevant_section='pr_description')
            if (
                get_settings()
                .get('config', {})
                .get('output_relevant_configurations', False)
            ):
                pr_body += show_relevant_configurations(
                    relevant_section='pr_description'
                )

            if get_settings().config.publish_output:

                # publish labels
                if get_settings().pr_description.publish_labels and pr_labels and self.git_provider.is_supported("get_labels"):
                if (
                    get_settings().pr_description.publish_labels
                    and pr_labels
                    and self.git_provider.is_supported("get_labels")
                ):
                    original_labels = self.git_provider.get_pr_labels(update=True)
                    get_logger().debug(f"original labels", artifact=original_labels)
                    user_labels = get_user_labels(original_labels)
@@ -165,20 +229,29 @@ class PRDescription:

                # publish description
                if get_settings().pr_description.publish_description_as_comment:
                    full_markdown_description = f"## Title\n\n{pr_title}\n\n___\n{pr_body}"
                    if get_settings().pr_description.publish_description_as_comment_persistent:
                        self.git_provider.publish_persistent_comment(full_markdown_description,
                                                                     initial_header="## Title",
                                                                     update_header=True,
                                                                     name="describe",
                                                                     final_update_message=False, )
                    full_markdown_description = (
                        f"## Title\n\n{pr_title}\n\n___\n{pr_body}"
                    )
                    if (
                        get_settings().pr_description.publish_description_as_comment_persistent
                    ):
                        self.git_provider.publish_persistent_comment(
                            full_markdown_description,
                            initial_header="## Title",
                            update_header=True,
                            name="describe",
                            final_update_message=False,
                        )
                    else:
                        self.git_provider.publish_comment(full_markdown_description)
                else:
                    self.git_provider.publish_description(pr_title, pr_body)

                # publish final update message
                if (get_settings().pr_description.final_update_message and not get_settings().config.get('is_auto_command', False)):
                if (
                    get_settings().pr_description.final_update_message
                    and not get_settings().config.get('is_auto_command', False)
                ):
                    latest_commit_url = self.git_provider.get_latest_commit_url()
                    if latest_commit_url:
                        pr_url = self.git_provider.get_pr_url()
@@ -186,22 +259,40 @@ class PRDescription:
                        self.git_provider.publish_comment(update_comment)
                self.git_provider.remove_initial_comment()
            else:
                get_logger().info('PR description, but not published since publish_output is False.')
                get_logger().info(
                    'PR description, but not published since publish_output is False.'
                )
                get_settings().data = {"artifact": pr_body}
                return
        except Exception as e:
            get_logger().error(f"Error generating PR description {self.pr_id}: {e}",
                               artifact={"traceback": traceback.format_exc()})
            get_logger().error(
                f"Error generating PR description {self.pr_id}: {e}",
                artifact={"traceback": traceback.format_exc()},
            )

        return ""

    async def _prepare_prediction(self, model: str) -> None:
        if get_settings().pr_description.use_description_markers and 'pr_agent:' not in self.user_description:
            get_logger().info("Markers were enabled, but user description does not contain markers. skipping AI prediction")
        if (
            get_settings().pr_description.use_description_markers
            and 'pr_agent:' not in self.user_description
        ):
            get_logger().info(
                "Markers were enabled, but user description does not contain markers. skipping AI prediction"
            )
            return None

        large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings()
        output = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling, return_remaining_files=True)
        large_pr_handling = (
            get_settings().pr_description.enable_large_pr_handling
            and "pr_description_only_files_prompts" in get_settings()
        )
        output = get_pr_diff(
            self.git_provider,
            self.token_handler,
            model,
            large_pr_handling=large_pr_handling,
            return_remaining_files=True,
        )
        if isinstance(output, tuple):
            patches_diff, remaining_files_list = output
        else:
@@ -213,14 +304,18 @@ class PRDescription:
            if patches_diff:
                # generate the prediction
                get_logger().debug(f"PR diff", artifact=self.patches_diff)
                self.prediction = await self._get_prediction(model, patches_diff, prompt="pr_description_prompt")
                self.prediction = await self._get_prediction(
                    model, patches_diff, prompt="pr_description_prompt"
                )

                # extend the prediction with additional files not shown
                if get_settings().pr_description.enable_semantic_files_types:
                    self.prediction = await self.extend_uncovered_files(self.prediction)
            else:
                get_logger().error(f"Error getting PR diff {self.pr_id}",
                                   artifact={"traceback": traceback.format_exc()})
                get_logger().error(
                    f"Error getting PR diff {self.pr_id}",
                    artifact={"traceback": traceback.format_exc()},
                )
                self.prediction = None
        else:
            # get the diff in multiple patches, with the token handler only for the files prompt
@@ -231,9 +326,16 @@ class PRDescription:
                get_settings().pr_description_only_files_prompts.system,
                get_settings().pr_description_only_files_prompts.user,
            )
            (patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict,
             files_in_patches_list) = get_pr_diff_multiple_patchs(
                self.git_provider, token_handler_only_files_prompt, model)
            (
                patches_compressed_list,
                total_tokens_list,
                deleted_files_list,
                remaining_files_list,
                file_dict,
                files_in_patches_list,
            ) = get_pr_diff_multiple_patchs(
                self.git_provider, token_handler_only_files_prompt, model
            )

            # get the files prediction for each patch
            if not get_settings().pr_description.async_ai_calls:
@@ -241,8 +343,9 @@ class PRDescription:
                for i, patches in enumerate(patches_compressed_list):  # sync calls
                    patches_diff = "\n".join(patches)
                    get_logger().debug(f"PR diff number {i + 1} for describe files")
                    prediction_files = await self._get_prediction(model, patches_diff,
                                                                  prompt="pr_description_only_files_prompts")
                    prediction_files = await self._get_prediction(
                        model, patches_diff, prompt="pr_description_only_files_prompts"
                    )
                    results.append(prediction_files)
            else:  # async calls
                tasks = []
@@ -251,34 +354,52 @@ class PRDescription:
                    patches_diff = "\n".join(patches)
                    get_logger().debug(f"PR diff number {i + 1} for describe files")
                    task = asyncio.create_task(
                        self._get_prediction(model, patches_diff, prompt="pr_description_only_files_prompts"))
                        self._get_prediction(
                            model,
                            patches_diff,
                            prompt="pr_description_only_files_prompts",
                        )
                    )
                    tasks.append(task)
                # Wait for all tasks to complete
                results = await asyncio.gather(*tasks)
            file_description_str_list = []
            for i, result in enumerate(results):
                prediction_files = result.strip().removeprefix('```yaml').strip('`').strip()
                if load_yaml(prediction_files, keys_fix_yaml=self.keys_fix) and prediction_files.startswith('pr_files'):
                    prediction_files = prediction_files.removeprefix('pr_files:').strip()
                prediction_files = (
                    result.strip().removeprefix('```yaml').strip('`').strip()
                )
                if load_yaml(
                    prediction_files, keys_fix_yaml=self.keys_fix
                ) and prediction_files.startswith('pr_files'):
                    prediction_files = prediction_files.removeprefix(
                        'pr_files:'
                    ).strip()
                    file_description_str_list.append(prediction_files)
                else:
                    get_logger().debug(f"failed to generate predictions in iteration {i + 1} for describe files")
                    get_logger().debug(
                        f"failed to generate predictions in iteration {i + 1} for describe files"
                    )

            # generate files_walkthrough string, with proper token handling
            token_handler_only_description_prompt = TokenHandler(
                self.git_provider.pr,
                self.vars,
                get_settings().pr_description_only_description_prompts.system,
                get_settings().pr_description_only_description_prompts.user)
                get_settings().pr_description_only_description_prompts.user,
            )
            files_walkthrough = "\n".join(file_description_str_list)
            files_walkthrough_prompt = copy.deepcopy(files_walkthrough)
            MAX_EXTRA_FILES_TO_PROMPT = 50
            if remaining_files_list:
                files_walkthrough_prompt += "\n\nNo more token budget. Additional unprocessed files:"
                files_walkthrough_prompt += (
                    "\n\nNo more token budget. Additional unprocessed files:"
                )
                for i, file in enumerate(remaining_files_list):
                    files_walkthrough_prompt += f"\n- {file}"
                    if i >= MAX_EXTRA_FILES_TO_PROMPT:
                        get_logger().debug(f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}")
                        get_logger().debug(
                            f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}"
                        )
                        files_walkthrough_prompt += f"\n... and {len(remaining_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more"
                        break
            if deleted_files_list:
@@ -286,32 +407,57 @@ class PRDescription:
                for i, file in enumerate(deleted_files_list):
                    files_walkthrough_prompt += f"\n- {file}"
                    if i >= MAX_EXTRA_FILES_TO_PROMPT:
                        get_logger().debug(f"Too many deleted files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}")
                        get_logger().debug(
                            f"Too many deleted files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}"
                        )
                        files_walkthrough_prompt += f"\n... and {len(deleted_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more"
                        break
            tokens_files_walkthrough = len(
                token_handler_only_description_prompt.encoder.encode(files_walkthrough_prompt))
            total_tokens = token_handler_only_description_prompt.prompt_tokens + tokens_files_walkthrough
                token_handler_only_description_prompt.encoder.encode(
                    files_walkthrough_prompt
                )
            )
            total_tokens = (
                token_handler_only_description_prompt.prompt_tokens
                + tokens_files_walkthrough
            )
            max_tokens_model = get_max_tokens(model)
            if total_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
                # clip files_walkthrough to get the tokens within the limit
                files_walkthrough_prompt = clip_tokens(files_walkthrough_prompt,
                                                       max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD - token_handler_only_description_prompt.prompt_tokens,
                                                       num_input_tokens=tokens_files_walkthrough)
                files_walkthrough_prompt = clip_tokens(
                    files_walkthrough_prompt,
                    max_tokens_model
                    - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD
                    - token_handler_only_description_prompt.prompt_tokens,
                    num_input_tokens=tokens_files_walkthrough,
                )

            # PR header inference
            get_logger().debug(f"PR diff only description", artifact=files_walkthrough_prompt)
            prediction_headers = await self._get_prediction(model, patches_diff=files_walkthrough_prompt,
                                                            prompt="pr_description_only_description_prompts")
            prediction_headers = prediction_headers.strip().removeprefix('```yaml').strip('`').strip()
            get_logger().debug(
                f"PR diff only description", artifact=files_walkthrough_prompt
            )
            prediction_headers = await self._get_prediction(
                model,
                patches_diff=files_walkthrough_prompt,
                prompt="pr_description_only_description_prompts",
            )
            prediction_headers = (
                prediction_headers.strip().removeprefix('```yaml').strip('`').strip()
            )

            # extend the tables with the files not shown
            files_walkthrough_extended = await self.extend_uncovered_files(files_walkthrough)
            files_walkthrough_extended = await self.extend_uncovered_files(
                files_walkthrough
            )

            # final processing
            self.prediction = prediction_headers + "\n" + "pr_files:\n" + files_walkthrough_extended
            self.prediction = (
                prediction_headers + "\n" + "pr_files:\n" + files_walkthrough_extended
            )
            if not load_yaml(self.prediction, keys_fix_yaml=self.keys_fix):
                get_logger().error(f"Error getting valid YAML in large PR handling for describe {self.pr_id}")
                get_logger().error(
                    f"Error getting valid YAML in large PR handling for describe {self.pr_id}"
                )
                if load_yaml(prediction_headers, keys_fix_yaml=self.keys_fix):
                    get_logger().debug(f"Using only headers for describe {self.pr_id}")
                    self.prediction = prediction_headers
@@ -321,12 +467,17 @@ class PRDescription:
            prediction = original_prediction

            # get the original prediction filenames
            original_prediction_loaded = load_yaml(original_prediction, keys_fix_yaml=self.keys_fix)
            original_prediction_loaded = load_yaml(
                original_prediction, keys_fix_yaml=self.keys_fix
            )
            if isinstance(original_prediction_loaded, list):
                original_prediction_dict = {"pr_files": original_prediction_loaded}
            else:
                original_prediction_dict = original_prediction_loaded
            filenames_predicted = [file['filename'].strip() for file in original_prediction_dict.get('pr_files', [])]
            filenames_predicted = [
                file['filename'].strip()
                for file in original_prediction_dict.get('pr_files', [])
            ]

            # extend the prediction with additional files not included in the original prediction
            pr_files = self.git_provider.get_diff_files()
@@ -349,7 +500,9 @@ class PRDescription:
    additional files
"""
                        prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
                        get_logger().debug(f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_OUTPUT}")
                        get_logger().debug(
                            f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_OUTPUT}"
                        )
                        break

                    extra_file_yaml = f"""\
@@ -364,10 +517,18 @@ class PRDescription:

            # merge the two dictionaries
            if counter_extra_files > 0:
                get_logger().info(f"Adding {counter_extra_files} unprocessed extra files to table prediction")
                prediction_extra_dict = load_yaml(prediction_extra, keys_fix_yaml=self.keys_fix)
                if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
                    original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
                get_logger().info(
                    f"Adding {counter_extra_files} unprocessed extra files to table prediction"
                )
                prediction_extra_dict = load_yaml(
                    prediction_extra, keys_fix_yaml=self.keys_fix
                )
                if isinstance(original_prediction_dict, dict) and isinstance(
                    prediction_extra_dict, dict
                ):
                    original_prediction_dict["pr_files"].extend(
                        prediction_extra_dict["pr_files"]
                    )
                    new_yaml = yaml.dump(original_prediction_dict)
                    if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix):
                        prediction = new_yaml
@@ -379,11 +540,12 @@ class PRDescription:
            get_logger().error(f"Error extending uncovered files {self.pr_id}: {e}")
            return original_prediction


    async def extend_additional_files(self, remaining_files_list) -> str:
        prediction = self.prediction
        try:
            original_prediction_dict = load_yaml(self.prediction, keys_fix_yaml=self.keys_fix)
            original_prediction_dict = load_yaml(
                self.prediction, keys_fix_yaml=self.keys_fix
            )
            prediction_extra = "pr_files:"
            for file in remaining_files_list:
                extra_file_yaml = f"""\
@@ -397,10 +559,16 @@ class PRDescription:
    additional files (token-limit)
"""
                prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
            prediction_extra_dict = load_yaml(prediction_extra, keys_fix_yaml=self.keys_fix)
            prediction_extra_dict = load_yaml(
                prediction_extra, keys_fix_yaml=self.keys_fix
            )
            # merge the two dictionaries
            if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
                original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
            if isinstance(original_prediction_dict, dict) and isinstance(
                prediction_extra_dict, dict
            ):
                original_prediction_dict["pr_files"].extend(
                    prediction_extra_dict["pr_files"]
                )
                new_yaml = yaml.dump(original_prediction_dict)
                if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix):
                    prediction = new_yaml
@@ -409,7 +577,9 @@ class PRDescription:
            get_logger().error(f"Error extending additional files {self.pr_id}: {e}")
        return self.prediction

    async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str:
    async def _get_prediction(
        self, model: str, patches_diff: str, prompt="pr_description_prompt"
    ) -> str:
        variables = copy.deepcopy(self.vars)
        variables["diff"] = patches_diff  # update diff

@@ -417,14 +587,18 @@ class PRDescription:
        set_custom_labels(variables, self.git_provider)
        self.variables = variables

        system_prompt = environment.from_string(get_settings().get(prompt, {}).get("system", "")).render(self.variables)
        user_prompt = environment.from_string(get_settings().get(prompt, {}).get("user", "")).render(self.variables)
        system_prompt = environment.from_string(
            get_settings().get(prompt, {}).get("system", "")
        ).render(self.variables)
        user_prompt = environment.from_string(
            get_settings().get(prompt, {}).get("user", "")
        ).render(self.variables)

        response, finish_reason = await self.ai_handler.chat_completion(
            model=model,
            temperature=get_settings().config.temperature,
            system=system_prompt,
            user=user_prompt
            user=user_prompt,
        )

        return response
@@ -433,7 +607,10 @@ class PRDescription:
        # Load the AI prediction data into a dictionary
        self.data = load_yaml(self.prediction.strip(), keys_fix_yaml=self.keys_fix)

        if get_settings().pr_description.add_original_user_description and self.user_description:
        if (
            get_settings().pr_description.add_original_user_description
            and self.user_description
        ):
            self.data["User Description"] = self.user_description

        # re-order keys
@@ -459,7 +636,11 @@ class PRDescription:
                pr_labels = self.data['labels']
            elif type(self.data['labels']) == str:
                pr_labels = self.data['labels'].split(',')
        elif 'type' in self.data and self.data['type'] and get_settings().pr_description.publish_labels:
        elif (
            'type' in self.data
            and self.data['type']
            and get_settings().pr_description.publish_labels
        ):
            if type(self.data['type']) == list:
                pr_labels = self.data['type']
            elif type(self.data['type']) == str:
@@ -474,7 +655,9 @@ class PRDescription:
                    if label_i in d:
                        pr_labels[i] = d[label_i]
        except Exception as e:
            get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
            get_logger().error(
                f"Error converting labels to original case {self.pr_id}: {e}"
            )
        return pr_labels

    def _prepare_pr_answer_with_markers(self) -> Tuple[str, str, str, List[dict]]:
@@ -482,13 +665,13 @@ class PRDescription:

        # Remove the 'PR Title' key from the dictionary
        ai_title = self.data.pop('title', self.vars["title"])
        if (not get_settings().pr_description.generate_ai_title):
        if not get_settings().pr_description.generate_ai_title:
            # Assign the original PR title to the 'title' variable
            title = self.vars["title"]
        else:
            # Assign the value of the 'PR Title' key to 'title' variable
            title = ai_title


        body = self.user_description
        if get_settings().pr_description.include_generated_by_header:
            ai_header = f"### 🤖 Generated by PR Agent at {self.git_provider.last_commit_id.sha}\n\n"
@@ -514,8 +697,9 @@ class PRDescription:
        pr_file_changes = []
        if ai_walkthrough and not re.search(r'<!--\s*pr_agent:walkthrough\s*-->', body):
            try:
                walkthrough_gfm, pr_file_changes = self.process_pr_files_prediction(walkthrough_gfm,
                                                                                    self.file_label_dict)
                walkthrough_gfm, pr_file_changes = self.process_pr_files_prediction(
                    walkthrough_gfm, self.file_label_dict
                )
                body = body.replace('pr_agent:walkthrough', walkthrough_gfm)
            except Exception as e:
                get_logger().error(f"Failing to process walkthrough {self.pr_id}: {e}")
@@ -545,7 +729,7 @@ class PRDescription:

        # Remove the 'PR Title' key from the dictionary
        ai_title = self.data.pop('title', self.vars["title"])
        if (not get_settings().pr_description.generate_ai_title):
        if not get_settings().pr_description.generate_ai_title:
            # Assign the original PR title to the 'title' variable
            title = self.vars["title"]
        else:
@@ -575,13 +759,20 @@ class PRDescription:
                    pr_body += f'- `{filename}`: {description}\n'
                if self.git_provider.is_supported("gfm_markdown"):
                    pr_body += "</details>\n"
            elif 'pr_files' in key.lower() and get_settings().pr_description.enable_semantic_files_types:
                changes_walkthrough, pr_file_changes = self.process_pr_files_prediction(changes_walkthrough, value)
            elif (
                'pr_files' in key.lower()
                and get_settings().pr_description.enable_semantic_files_types
            ):
                changes_walkthrough, pr_file_changes = self.process_pr_files_prediction(
                    changes_walkthrough, value
                )
                changes_walkthrough = f"{PRDescriptionHeader.CHANGES_WALKTHROUGH.value}\n{changes_walkthrough}"
            elif key.lower().strip() == 'description':
                if isinstance(value, list):
                    value = ', '.join(v.rstrip() for v in value)
                value = value.replace('\n-', '\n\n-').strip()  # makes the bullet points more readable by adding double space
                value = value.replace(
                    '\n-', '\n\n-'
                ).strip()  # makes the bullet points more readable by adding double space
                pr_body += f"{value}\n"
            else:
                # if the value is a list, join its items by comma
@@ -591,24 +782,37 @@ class PRDescription:
            if idx < len(self.data) - 1:
                pr_body += "\n\n___\n\n"

        return title, pr_body, changes_walkthrough, pr_file_changes,
        return (
            title,
            pr_body,
            changes_walkthrough,
            pr_file_changes,
        )

    def _prepare_file_labels(self):
        file_label_dict = {}
        if (not self.data or not isinstance(self.data, dict) or
                'pr_files' not in self.data or not self.data['pr_files']):
        if (
            not self.data
            or not isinstance(self.data, dict)
            or 'pr_files' not in self.data
            or not self.data['pr_files']
        ):
            return file_label_dict
        for file in self.data['pr_files']:
            try:
                required_fields = ['changes_title', 'filename', 'label']
                if not all(field in file for field in required_fields):
                    # can happen for example if a YAML generation was interrupted in the middle (no more tokens)
                    get_logger().warning(f"Missing required fields in file label dict {self.pr_id}, skipping file",
                                         artifact={"file": file})
                    get_logger().warning(
                        f"Missing required fields in file label dict {self.pr_id}, skipping file",
                        artifact={"file": file},
                    )
                    continue
                if not file.get('changes_title'):
                    get_logger().warning(f"Empty changes title or summary in file label dict {self.pr_id}, skipping file",
                                         artifact={"file": file})
                    get_logger().warning(
                        f"Empty changes title or summary in file label dict {self.pr_id}, skipping file",
                        artifact={"file": file},
                    )
                    continue
                filename = file['filename'].replace("'", "`").replace('"', '`')
                changes_summary = file.get('changes_summary', "").strip()
@@ -616,7 +820,9 @@ class PRDescription:
                label = file.get('label').strip().lower()
                if label not in file_label_dict:
                    file_label_dict[label] = []
                file_label_dict[label].append((filename, changes_title, changes_summary))
                file_label_dict[label].append(
                    (filename, changes_title, changes_summary)
                )
            except Exception as e:
                get_logger().error(f"Error preparing file label dict {self.pr_id}: {e}")
                pass
@@ -640,7 +846,9 @@ class PRDescription:
                header = f"相关文件"
                delta = 75
                # header += " " * delta
                pr_body += f"""<thead><tr><th></th><th align="left">{header}</th></tr></thead>"""
                pr_body += (
                    f"""<thead><tr><th></th><th align="left">{header}</th></tr></thead>"""
                )
                pr_body += """<tbody>"""
                for semantic_label in value.keys():
                    s_label = semantic_label.strip("'").strip('"')
@@ -651,14 +859,22 @@ class PRDescription:
                        pr_body += f"""<td><details><summary>{len(list_tuples)} files</summary><table>"""
                    else:
                        pr_body += f"""<td><table>"""
                    for filename, file_changes_title, file_change_description in list_tuples:
                    for (
                        filename,
                        file_changes_title,
                        file_change_description,
                    ) in list_tuples:
                        filename = filename.replace("'", "`").rstrip()
                        filename_publish = filename.split("/")[-1]
                        if file_changes_title and file_changes_title.strip() != "...":
                            file_changes_title_code = f"<code>{file_changes_title}</code>"
                            file_changes_title_code_br = insert_br_after_x_chars(file_changes_title_code, x=(delta - 5)).strip()
                            file_changes_title_code_br = insert_br_after_x_chars(
                                file_changes_title_code, x=(delta - 5)
                            ).strip()
                            if len(file_changes_title_code_br) < (delta - 5):
                                file_changes_title_code_br += " " * ((delta - 5) - len(file_changes_title_code_br))
                                file_changes_title_code_br += " " * (
                                    (delta - 5) - len(file_changes_title_code_br)
                                )
                            filename_publish = f"<strong>{filename_publish}</strong><dd>{file_changes_title_code_br}</dd>"
                        else:
                            filename_publish = f"<strong>{filename_publish}</strong>"
@@ -679,15 +895,30 @@ class PRDescription:
                        link = ""
                        if hasattr(self.git_provider, 'get_line_link'):
                            filename = filename.strip()
                            link = self.git_provider.get_line_link(filename, relevant_line_start=-1)
                            if (not link or not diff_plus_minus) and ('additional files' not in filename.lower()):
                                get_logger().warning(f"Error getting line link for '{filename}'")
                            link = self.git_provider.get_line_link(
                                filename, relevant_line_start=-1
                            )
                            if (not link or not diff_plus_minus) and (
                                'additional files' not in filename.lower()
                            ):
                                get_logger().warning(
                                    f"Error getting line link for '{filename}'"
                                )
                                continue

                        # Add file data to the PR body
                        file_change_description_br = insert_br_after_x_chars(file_change_description, x=(delta - 5))
                        pr_body = self.add_file_data(delta_nbsp, diff_plus_minus, file_change_description_br, filename,
                                                     filename_publish, link, pr_body)
                        file_change_description_br = insert_br_after_x_chars(
                            file_change_description, x=(delta - 5)
                        )
                        pr_body = self.add_file_data(
                            delta_nbsp,
                            diff_plus_minus,
                            file_change_description_br,
                            filename,
                            filename_publish,
                            link,
                            pr_body,
                        )

            # Close the collapsible file list
            if use_collapsible_file_list:
@@ -697,13 +928,22 @@ class PRDescription:
                pr_body += """</tr></tbody></table>"""

        except Exception as e:
            get_logger().error(f"Error processing pr files to markdown {self.pr_id}: {str(e)}")
            get_logger().error(
                f"Error processing pr files to markdown {self.pr_id}: {str(e)}"
            )
            pass
        return pr_body, pr_comments

    def add_file_data(self, delta_nbsp, diff_plus_minus, file_change_description_br, filename, filename_publish, link,
                      pr_body) -> str:

    def add_file_data(
        self,
        delta_nbsp,
        diff_plus_minus,
        file_change_description_br,
        filename,
        filename_publish,
        link,
        pr_body,
    ) -> str:
        if not file_change_description_br:
            pr_body += f"""
<tr>
@@ -735,6 +975,7 @@ class PRDescription:
"""
        return pr_body


def count_chars_without_html(string):
    if '<' not in string:
        return len(string)

@@ -16,8 +16,12 @@ from utils.pr_agent.log import get_logger


class PRGenerateLabels:
    def __init__(self, pr_url: str, args: list = None,
                 ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
    def __init__(
        self,
        pr_url: str,
        args: list = None,
        ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
    ):
        """
        Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
        corresponding to the PR using an AI model.
@@ -93,7 +97,9 @@ class PRGenerateLabels:
                elif pr_labels:
                    value = ', '.join(v for v in pr_labels)
                    pr_labels_text = f"## PR Labels:\n{value}\n"
                    self.git_provider.publish_comment(pr_labels_text, is_temporary=False)
                    self.git_provider.publish_comment(
                        pr_labels_text, is_temporary=False
                    )
                self.git_provider.remove_initial_comment()
        except Exception as e:
            get_logger().error(f"Error generating PR labels {self.pr_id}: {e}")
@@ -137,14 +143,18 @@ class PRGenerateLabels:
        set_custom_labels(variables, self.git_provider)
        self.variables = variables

        system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(self.variables)
        user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(self.variables)
        system_prompt = environment.from_string(
            get_settings().pr_custom_labels_prompt.system
        ).render(self.variables)
        user_prompt = environment.from_string(
            get_settings().pr_custom_labels_prompt.user
        ).render(self.variables)

        response, finish_reason = await self.ai_handler.chat_completion(
            model=model,
            temperature=get_settings().config.temperature,
            system=system_prompt,
            user=user_prompt
            user=user_prompt,
        )

        return response
@@ -153,8 +163,6 @@ class PRGenerateLabels:
        # Load the AI prediction data into a dictionary
        self.data = load_yaml(self.prediction.strip())



    def _prepare_labels(self) -> List[str]:
        pr_types = []

@@ -174,6 +182,8 @@ class PRGenerateLabels:
                    if label_i in d:
                        pr_types[i] = d[label_i]
        except Exception as e:
            get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
            get_logger().error(
                f"Error converting labels to original case {self.pr_id}: {e}"
            )

        return pr_types

@ -12,7 +12,11 @@ from utils.pr_agent.algo.pr_processing import retry_with_fallback_models
|
||||
from utils.pr_agent.algo.token_handler import TokenHandler
|
||||
from utils.pr_agent.algo.utils import ModelType, clip_tokens, load_yaml, get_max_tokens
|
||||
from utils.pr_agent.config_loader import get_settings
|
||||
from utils.pr_agent.git_providers import BitbucketServerProvider, GithubProvider, get_git_provider_with_context
|
||||
from utils.pr_agent.git_providers import (
|
||||
BitbucketServerProvider,
|
||||
GithubProvider,
|
||||
get_git_provider_with_context,
|
||||
)
|
||||
from utils.pr_agent.log import get_logger
|
||||
|
||||
|
||||
@ -29,31 +33,50 @@ def extract_header(snippet):
|
||||
res = f"#{highest_header.lower().replace(' ', '-')}"
|
||||
return res
|
||||
|
||||
|
||||
class PRHelpMessage:
|
||||
def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, return_as_string=False):
|
||||
def __init__(
|
||||
self,
|
||||
pr_url: str,
|
||||
args=None,
|
||||
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
|
||||
return_as_string=False,
|
||||
):
|
||||
self.git_provider = get_git_provider_with_context(pr_url)
|
||||
self.ai_handler = ai_handler()
|
||||
self.question_str = self.parse_args(args)
|
||||
self.return_as_string = return_as_string
|
||||
self.num_retrieved_snippets = get_settings().get('pr_help.num_retrieved_snippets', 5)
|
||||
self.num_retrieved_snippets = get_settings().get(
|
||||
'pr_help.num_retrieved_snippets', 5
|
||||
)
|
||||
if self.question_str:
|
||||
self.vars = {
|
||||
"question": self.question_str,
|
||||
"snippets": "",
|
||||
}
|
||||
self.token_handler = TokenHandler(None,
|
||||
self.vars,
|
||||
get_settings().pr_help_prompts.system,
|
||||
get_settings().pr_help_prompts.user)
|
||||
self.token_handler = TokenHandler(
|
||||
None,
|
||||
self.vars,
|
||||
get_settings().pr_help_prompts.system,
|
||||
get_settings().pr_help_prompts.user,
|
||||
)
|
||||
|
||||
async def _prepare_prediction(self, model: str):
|
||||
try:
|
||||
variables = copy.deepcopy(self.vars)
|
||||
environment = Environment(undefined=StrictUndefined)
|
||||
system_prompt = environment.from_string(get_settings().pr_help_prompts.system).render(variables)
|
||||
user_prompt = environment.from_string(get_settings().pr_help_prompts.user).render(variables)
|
||||
system_prompt = environment.from_string(
|
||||
get_settings().pr_help_prompts.system
|
||||
).render(variables)
|
||||
user_prompt = environment.from_string(
|
||||
get_settings().pr_help_prompts.user
|
||||
).render(variables)
|
||||
response, finish_reason = await self.ai_handler.chat_completion(
|
||||
model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
|
||||
model=model,
|
||||
temperature=get_settings().config.temperature,
|
||||
system=system_prompt,
|
||||
user=user_prompt,
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
get_logger().error(f"Error while preparing prediction: {e}")
|
||||
@ -81,7 +104,7 @@ class PRHelpMessage:
|
||||
'.': '',
|
||||
'?': '',
|
||||
'!': '',
|
||||
' ': '-'
|
||||
' ': '-',
|
||||
}
|
||||
|
||||
# Compile regex pattern for characters to remove
|
||||
@ -90,37 +113,69 @@ class PRHelpMessage:
|
||||
# Perform replacements in a single pass and convert to lowercase
|
||||
return pattern.sub(lambda m: replacements[m.group()], cleaned).lower()
|
||||
except Exception:
|
||||
get_logger().exception(f"Error while formatting markdown header", artifacts={'header': header})
|
||||
get_logger().exception(
|
||||
f"Error while formatting markdown header", artifacts={'header': header}
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
if self.question_str:
|
||||
get_logger().info(f'Answering a PR question about the PR {self.git_provider.pr_url} ')
|
||||
get_logger().info(
|
||||
f'Answering a PR question about the PR {self.git_provider.pr_url} '
|
||||
)
|
||||
|
||||
if not get_settings().get('openai.key'):
|
||||
if get_settings().config.publish_output:
|
||||
self.git_provider.publish_comment(
|
||||
"The `Help` tool chat feature requires an OpenAI API key for calculating embeddings")
|
||||
"The `Help` tool chat feature requires an OpenAI API key for calculating embeddings"
|
||||
)
|
||||
else:
|
||||
get_logger().error("The `Help` tool chat feature requires an OpenAI API key for calculating embeddings")
|
||||
get_logger().error(
|
||||
"The `Help` tool chat feature requires an OpenAI API key for calculating embeddings"
|
||||
)
|
||||
return
|
||||
|
||||
# current path
|
||||
docs_path= Path(__file__).parent.parent.parent / 'docs' / 'docs'
|
||||
docs_path = Path(__file__).parent.parent.parent / 'docs' / 'docs'
|
||||
# get all the 'md' files inside docs_path and its subdirectories
|
||||
md_files = list(docs_path.glob('**/*.md'))
|
||||
folders_to_exclude = ['/finetuning_benchmark/']
|
||||
files_to_exclude = {'EXAMPLE_BEST_PRACTICE.md', 'compression_strategy.md', '/docs/overview/index.md'}
|
||||
md_files = [file for file in md_files if not any(folder in str(file) for folder in folders_to_exclude) and not any(file.name == file_to_exclude for file_to_exclude in files_to_exclude)]
|
||||
files_to_exclude = {
|
||||
'EXAMPLE_BEST_PRACTICE.md',
|
||||
'compression_strategy.md',
|
||||
'/docs/overview/index.md',
|
||||
}
|
||||
md_files = [
|
||||
file
|
||||
for file in md_files
|
||||
if not any(folder in str(file) for folder in folders_to_exclude)
|
||||
and not any(
|
||||
file.name == file_to_exclude
|
||||
for file_to_exclude in files_to_exclude
|
||||
)
|
||||
]
|
||||
|
||||
# sort the 'md_files' so that 'priority_files' will be at the top
|
||||
priority_files_strings = ['/docs/index.md', '/usage-guide', 'tools/describe.md', 'tools/review.md',
|
||||
'tools/improve.md', '/faq']
|
||||
md_files_priority = [file for file in md_files if
|
||||
any(priority_string in str(file) for priority_string in priority_files_strings)]
|
||||
md_files_not_priority = [file for file in md_files if file not in md_files_priority]
|
||||
priority_files_strings = [
|
||||
'/docs/index.md',
|
||||
'/usage-guide',
|
||||
'tools/describe.md',
|
||||
'tools/review.md',
|
||||
'tools/improve.md',
|
||||
'/faq',
|
||||
]
|
||||
md_files_priority = [
|
||||
file
|
||||
for file in md_files
|
||||
if any(
|
||||
priority_string in str(file)
|
||||
for priority_string in priority_files_strings
|
||||
)
|
||||
]
|
||||
md_files_not_priority = [
|
||||
file for file in md_files if file not in md_files_priority
|
||||
]
|
||||
md_files = md_files_priority + md_files_not_priority
|
||||
|
||||
docs_prompt = ""
|
||||
@ -132,24 +187,36 @@ class PRHelpMessage:
|
||||
except Exception as e:
|
||||
get_logger().error(f"Error while reading the file {file}: {e}")
|
||||
token_count = self.token_handler.count_tokens(docs_prompt)
|
||||
get_logger().debug(f"Token count of full documentation website: {token_count}")
|
||||
get_logger().debug(
|
||||
f"Token count of full documentation website: {token_count}"
|
||||
)
|
||||
|
||||
model = get_settings().config.model
|
||||
if model in MAX_TOKENS:
|
||||
max_tokens_full = MAX_TOKENS[model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
|
||||
max_tokens_full = MAX_TOKENS[
|
||||
model
|
||||
] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
|
||||
else:
|
||||
max_tokens_full = get_max_tokens(model)
|
||||
delta_output = 2000
|
||||
if token_count > max_tokens_full - delta_output:
|
||||
get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message.")
|
||||
docs_prompt = clip_tokens(docs_prompt, max_tokens_full - delta_output)
|
||||
get_logger().info(
|
||||
f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message."
|
||||
)
|
||||
docs_prompt = clip_tokens(
|
||||
docs_prompt, max_tokens_full - delta_output
|
||||
)
|
||||
self.vars['snippets'] = docs_prompt.strip()
|
||||
|
||||
# run the AI model
|
||||
response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
|
||||
response = await retry_with_fallback_models(
|
||||
self._prepare_prediction, model_type=ModelType.REGULAR
|
||||
)
|
||||
response_yaml = load_yaml(response)
|
||||
if isinstance(response_yaml, str):
|
||||
get_logger().warning(f"failing to parse response: {response_yaml}, publishing the response as is")
|
||||
get_logger().warning(
|
||||
f"failing to parse response: {response_yaml}, publishing the response as is"
|
||||
)
|
||||
if get_settings().config.publish_output:
|
||||
answer_str = f"### Question: \n{self.question_str}\n\n"
|
||||
answer_str += f"### Answer:\n\n"
|
||||
@@ -160,7 +227,9 @@ class PRHelpMessage:
relevant_sections = response_yaml.get('relevant_sections')

if not relevant_sections:
get_logger().info(f"Could not find relevant answer for the question: {self.question_str}")
get_logger().info(
f"Could not find relevant answer for the question: {self.question_str}"
)
if get_settings().config.publish_output:
answer_str = f"### Question: \n{self.question_str}\n\n"
answer_str += f"### Answer:\n\n"
@@ -178,29 +247,38 @@ class PRHelpMessage:
for section in relevant_sections:
file = section.get('file_name').strip().removesuffix('.md')
if str(section['relevant_section_header_string']).strip():
markdown_header = self.format_markdown_header(section['relevant_section_header_string'])
markdown_header = self.format_markdown_header(
section['relevant_section_header_string']
)
answer_str += f"> - {base_path}{file}#{markdown_header}\n"
else:
answer_str += f"> - {base_path}{file}\n"


# publish the answer
if get_settings().config.publish_output:
self.git_provider.publish_comment(answer_str)
else:
get_logger().info(f"Answer:\n{answer_str}")
else:
if not isinstance(self.git_provider, BitbucketServerProvider) and not self.git_provider.is_supported("gfm_markdown"):
if not isinstance(
self.git_provider, BitbucketServerProvider
) and not self.git_provider.is_supported("gfm_markdown"):
self.git_provider.publish_comment(
"The `Help` tool requires gfm markdown, which is not supported by your code platform.")
"The `Help` tool requires gfm markdown, which is not supported by your code platform."
)
return

get_logger().info('Getting PR Help Message...')
relevant_configs = {'pr_help': dict(get_settings().pr_help),
'config': dict(get_settings().config)}
relevant_configs = {
'pr_help': dict(get_settings().pr_help),
'config': dict(get_settings().config),
}
get_logger().debug("Relevant configs", artifacts=relevant_configs)
pr_comment = "## PR Agent Walkthrough 🤖\n\n"
pr_comment += "Welcome to the PR Agent, an AI-powered tool for automated pull request analysis, feedback, suggestions and more."""
pr_comment += (
"Welcome to the PR Agent, an AI-powered tool for automated pull request analysis, feedback, suggestions and more."
""
)
pr_comment += "\n\nHere is a list of tools you can use to interact with the PR Agent:\n"
base_path = "https://pr-agent-docs.codium.ai/tools"
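format_markdown_header itself is not shown in this diff; a plausible sketch of turning a section header into a link anchor is below. The slug rules are an assumption, not the project's actual implementation:

import re

def format_markdown_header(header: str) -> str:
    # Assumed GitHub-style slug rules: lowercase, drop punctuation, spaces -> hyphens.
    slug = header.strip().lower()
    slug = re.sub(r'[^\w\s-]', '', slug)
    return re.sub(r'[\s_]+', '-', slug)

print(format_markdown_header("Overview / Getting Started"))  # overview-getting-started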
@@ -211,32 +289,58 @@ class PRHelpMessage:
tool_names.append(f"[UPDATE CHANGELOG]({base_path}/update_changelog/)")
tool_names.append(f"[ADD DOCS]({base_path}/documentation/) 💎")
tool_names.append(f"[TEST]({base_path}/test/) 💎")
tool_names.append(f"[IMPROVE COMPONENT]({base_path}/improve_component/) 💎")
tool_names.append(
f"[IMPROVE COMPONENT]({base_path}/improve_component/) 💎"
)
tool_names.append(f"[ANALYZE]({base_path}/analyze/) 💎")
tool_names.append(f"[ASK]({base_path}/ask/)")
tool_names.append(f"[SIMILAR ISSUE]({base_path}/similar_issues/)")
tool_names.append(f"[GENERATE CUSTOM LABELS]({base_path}/custom_labels/) 💎")
tool_names.append(
f"[GENERATE CUSTOM LABELS]({base_path}/custom_labels/) 💎"
)
tool_names.append(f"[CI FEEDBACK]({base_path}/ci_feedback/) 💎")
tool_names.append(f"[CUSTOM PROMPT]({base_path}/custom_prompt/) 💎")
tool_names.append(f"[IMPLEMENT]({base_path}/implement/) 💎")

descriptions = []
descriptions.append("Generates PR description - title, type, summary, code walkthrough and labels")
descriptions.append("Adjustable feedback about the PR, possible issues, security concerns, review effort and more")
descriptions.append(
"Generates PR description - title, type, summary, code walkthrough and labels"
)
descriptions.append(
"Adjustable feedback about the PR, possible issues, security concerns, review effort and more"
)
descriptions.append("Code suggestions for improving the PR")
descriptions.append("Automatically updates the changelog")
descriptions.append("Generates documentation to methods/functions/classes that changed in the PR")
descriptions.append("Generates unit tests for a specific component, based on the PR code change")
descriptions.append("Code suggestions for a specific component that changed in the PR")
descriptions.append("Identifies code components that changed in the PR, and enables to interactively generate tests, docs, and code suggestions for each component")
descriptions.append(
"Generates documentation to methods/functions/classes that changed in the PR"
)
descriptions.append(
"Generates unit tests for a specific component, based on the PR code change"
)
descriptions.append(
"Code suggestions for a specific component that changed in the PR"
)
descriptions.append(
"Identifies code components that changed in the PR, and enables to interactively generate tests, docs, and code suggestions for each component"
)
descriptions.append("Answering free-text questions about the PR")
descriptions.append("Automatically retrieves and presents similar issues")
descriptions.append("Generates custom labels for the PR, based on specific guidelines defined by the user")
descriptions.append("Generates feedback and analysis for a failed CI job")
descriptions.append("Generates custom suggestions for improving the PR code, derived only from a specific guidelines prompt defined by the user")
descriptions.append("Generates implementation code from review suggestions")
descriptions.append(
"Automatically retrieves and presents similar issues"
)
descriptions.append(
"Generates custom labels for the PR, based on specific guidelines defined by the user"
)
descriptions.append(
"Generates feedback and analysis for a failed CI job"
)
descriptions.append(
"Generates custom suggestions for improving the PR code, derived only from a specific guidelines prompt defined by the user"
)
descriptions.append(
"Generates implementation code from review suggestions"
)

commands =[]
commands = []
commands.append("`/describe`")
commands.append("`/review`")
commands.append("`/improve`")
@@ -271,7 +375,9 @@ class PRHelpMessage:
checkbox_list.append("[*]")
checkbox_list.append("[*]")

if isinstance(self.git_provider, GithubProvider) and not get_settings().config.get('disable_checkboxes', False):
if isinstance(
self.git_provider, GithubProvider
) and not get_settings().config.get('disable_checkboxes', False):
pr_comment += f"<table><tr align='left'><th align='left'>Tool</th><th align='left'>Description</th><th align='left'>Trigger Interactively :gem:</th></tr>"
for i in range(len(tool_names)):
pr_comment += f"\n<tr><td align='left'>\n\n<strong>{tool_names[i]}</strong></td>\n<td>{descriptions[i]}</td>\n<td>\n\n{checkbox_list[i]}\n</td></tr>"
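The loop at the end of the hunk zips three parallel lists into one HTML table row per tool. The same pattern in isolation (the sample row values are illustrative):

tool_names = ["[DESCRIBE](https://pr-agent-docs.codium.ai/tools/describe/)"]
descriptions = ["Generates PR description - title, type, summary, code walkthrough and labels"]
checkbox_list = ["[*]"]

pr_comment = "<table><tr align='left'><th align='left'>Tool</th><th align='left'>Description</th><th align='left'>Trigger Interactively :gem:</th></tr>"
for name, desc, box in zip(tool_names, descriptions, checkbox_list):
    pr_comment += (f"\n<tr><td align='left'>\n\n<strong>{name}</strong></td>"
                   f"\n<td>{desc}</td>\n<td>\n\n{box}\n</td></tr>")
pr_comment += "</table>"
print(pr_comment)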
@@ -5,8 +5,7 @@ from jinja2 import Environment, StrictUndefined

from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from utils.pr_agent.algo.git_patch_processing import (
extract_hunk_lines_from_patch)
from utils.pr_agent.algo.git_patch_processing import extract_hunk_lines_from_patch
from utils.pr_agent.algo.pr_processing import retry_with_fallback_models
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.utils import ModelType
@@ -17,7 +16,12 @@ from utils.pr_agent.log import get_logger


class PR_LineQuestions:
def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
def __init__(
self,
pr_url: str,
args=None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
):
self.question_str = self.parse_args(args)
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
@@ -34,10 +38,12 @@ class PR_LineQuestions:
"full_hunk": "",
"selected_lines": "",
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_line_questions_prompt.system,
get_settings().pr_line_questions_prompt.user)
self.token_handler = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_line_questions_prompt.system,
get_settings().pr_line_questions_prompt.user,
)
self.patches_diff = None
self.prediction = None

@@ -48,7 +54,6 @@ class PR_LineQuestions:
question_str = ""
return question_str


async def run(self):
get_logger().info('Answering a PR lines question...')
# if get_settings().config.publish_output:
@@ -62,22 +67,27 @@ class PR_LineQuestions:
file_name = get_settings().get('file_name', '')
comment_id = get_settings().get('comment_id', '')
if ask_diff:
self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch(ask_diff,
file_name,
line_start=line_start,
line_end=line_end,
side=side
)
self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch(
ask_diff, file_name, line_start=line_start, line_end=line_end, side=side
)
else:
diff_files = self.git_provider.get_diff_files()
for file in diff_files:
if file.filename == file_name:
self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch(file.patch, file.filename,
line_start=line_start,
line_end=line_end,
side=side)
(
self.patch_with_lines,
self.selected_lines,
) = extract_hunk_lines_from_patch(
file.patch,
file.filename,
line_start=line_start,
line_end=line_end,
side=side,
)
if self.patch_with_lines:
model_answer = await retry_with_fallback_models(self._get_prediction, model_type=ModelType.WEAK)
model_answer = await retry_with_fallback_models(
self._get_prediction, model_type=ModelType.WEAK
)
# sanitize the answer so that no line will start with "/"
model_answer_sanitized = model_answer.strip().replace("\n/", "\n /")
if model_answer_sanitized.startswith("/"):
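The sanitization above stops any line of the model answer from beginning with "/", which the chat platform could otherwise treat as a new bot command. The same two steps in isolation:

model_answer = "/improve\nsome explanation\n/review"
sanitized = model_answer.strip().replace("\n/", "\n /")
if sanitized.startswith("/"):
    sanitized = " " + sanitized  # mirrors the branch cut off at the hunk boundary above
print(sanitized)  # no line now begins with "/"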
@@ -85,7 +95,9 @@ class PR_LineQuestions:

get_logger().info('Preparing answer...')
if comment_id:
self.git_provider.reply_to_comment_from_comment_id(comment_id, model_answer_sanitized)
self.git_provider.reply_to_comment_from_comment_id(
comment_id, model_answer_sanitized
)
else:
self.git_provider.publish_comment(model_answer_sanitized)

@@ -96,8 +108,12 @@ class PR_LineQuestions:
variables["full_hunk"] = self.patch_with_lines # update diff
variables["selected_lines"] = self.selected_lines
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_line_questions_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_line_questions_prompt.user).render(variables)
system_prompt = environment.from_string(
get_settings().pr_line_questions_prompt.system
).render(variables)
user_prompt = environment.from_string(
get_settings().pr_line_questions_prompt.user
).render(variables)
if get_settings().config.verbosity_level >= 2:
# get_logger().info(f"\nSystem prompt:\n{system_prompt}")
# get_logger().info(f"\nUser prompt:\n{user_prompt}")
@@ -105,5 +121,9 @@ class PR_LineQuestions:
print(f"\nUser prompt:\n{user_prompt}")

response, finish_reason = await self.ai_handler.chat_completion(
model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
user=user_prompt,
)
return response
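Both prompt templates are rendered with Jinja2's StrictUndefined, so a variable missing from vars raises instead of silently rendering as empty text. A minimal sketch:

from jinja2 import Environment, StrictUndefined
from jinja2.exceptions import UndefinedError

environment = Environment(undefined=StrictUndefined)
template = environment.from_string("Answer about {{ file_name }}, lines {{ selected_lines }}")
print(template.render(file_name="app.py", selected_lines="10-20"))

try:
    template.render(file_name="app.py")  # selected_lines missing
except UndefinedError as err:
    print(f"template error: {err}")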
@@ -16,7 +16,12 @@ from utils.pr_agent.servers.help import HelpMessage


class PRQuestions:
def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
def __init__(
self,
pr_url: str,
args=None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
):
question_str = self.parse_args(args)
self.pr_url = pr_url
self.git_provider = get_git_provider()(pr_url)
@@ -36,10 +41,12 @@ class PRQuestions:
"questions": self.question_str,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_questions_prompt.system,
get_settings().pr_questions_prompt.user)
self.token_handler = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_questions_prompt.system,
get_settings().pr_questions_prompt.user,
)
self.patches_diff = None
self.prediction = None

@@ -52,8 +59,10 @@ class PRQuestions:

async def run(self):
get_logger().info(f'Answering a PR question about the PR {self.pr_url} ')
relevant_configs = {'pr_questions': dict(get_settings().pr_questions),
'config': dict(get_settings().config)}
relevant_configs = {
'pr_questions': dict(get_settings().pr_questions),
'config': dict(get_settings().config),
}
get_logger().debug("Relevant configs", artifacts=relevant_configs)
if get_settings().config.publish_output:
self.git_provider.publish_comment("思考回答中...", is_temporary=True)
@@ -63,12 +72,17 @@ class PRQuestions:
if img_path:
get_logger().debug(f"Image path identified", artifact=img_path)

await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK)
await retry_with_fallback_models(
self._prepare_prediction, model_type=ModelType.WEAK
)

pr_comment = self._prepare_pr_answer()
get_logger().debug(f"PR output", artifact=pr_comment)

if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_questions.enable_help_text:
if (
self.git_provider.is_supported("gfm_markdown")
and get_settings().pr_questions.enable_help_text
):
pr_comment += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
pr_comment += HelpMessage.get_ask_usage_guide()
pr_comment += "\n</details>\n"
@@ -85,7 +99,9 @@ class PRQuestions:
# /ask question ... > 
img_path = self.question_str.split('![image]')[1].strip().strip('()')
self.vars['img_path'] = img_path
elif 'https://' in self.question_str and ('.png' in self.question_str or 'jpg' in self.question_str): # direct image link
elif 'https://' in self.question_str and (
'.png' in self.question_str or 'jpg' in self.question_str
): # direct image link
# include https:// in the image path
img_path = 'https://' + self.question_str.split('https://')[1]
self.vars['img_path'] = img_path
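The branches above pull an image path either out of markdown image syntax or out of a bare https link (note the original checks the substring 'jpg' without a leading dot). Standalone:

def parse_img_path(question_str: str):
    if '![image]' in question_str:
        return question_str.split('![image]')[1].strip().strip('()')
    if 'https://' in question_str and ('.png' in question_str or 'jpg' in question_str):
        return 'https://' + question_str.split('https://')[1]
    return None

print(parse_img_path("what is this? ![image](https://example.com/x.png)"))
# https://example.com/x.png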
@@ -104,16 +120,28 @@ class PRQuestions:
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_questions_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_questions_prompt.user).render(variables)
system_prompt = environment.from_string(
get_settings().pr_questions_prompt.system
).render(variables)
user_prompt = environment.from_string(
get_settings().pr_questions_prompt.user
).render(variables)
if 'img_path' in variables:
img_path = self.vars['img_path']
response, finish_reason = await (self.ai_handler.chat_completion
(model=model, temperature=get_settings().config.temperature,
system=system_prompt, user=user_prompt, img_path=img_path))
response, finish_reason = await self.ai_handler.chat_completion(
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
user=user_prompt,
img_path=img_path,
)
else:
response, finish_reason = await self.ai_handler.chat_completion(
model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
user=user_prompt,
)
return response

def _prepare_pr_answer(self) -> str:
@@ -123,9 +151,13 @@ class PRQuestions:
if model_answer_sanitized.startswith("/"):
model_answer_sanitized = " " + model_answer_sanitized
if model_answer_sanitized != model_answer:
get_logger().debug(f"Sanitized model answer",
artifact={"model_answer": model_answer, "sanitized_answer": model_answer_sanitized})

get_logger().debug(
f"Sanitized model answer",
artifact={
"model_answer": model_answer,
"sanitized_answer": model_answer_sanitized,
},
)

answer_str = f"### **Ask**❓\n{self.question_str}\n\n"
answer_str += f"### **Answer:**\n{model_answer_sanitized}\n\n"
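The two chat_completion call sites above differ only in the img_path argument. A runnable sketch of collapsing them with a conditional kwargs dict — a possible simplification, not what the diff itself does (the stub handler and model name are illustrative):

import asyncio

class StubAIHandler:
    # Illustrative stand-in for the real LiteLLM handler.
    async def chat_completion(self, *, model, temperature, system, user, img_path=None):
        return f"answer (img={img_path})", "stop"

async def ask(handler, model, system_prompt, user_prompt, img_path=None):
    kwargs = dict(model=model, temperature=0.2, system=system_prompt, user=user_prompt)
    if img_path:
        kwargs["img_path"] = img_path  # single call site instead of two branches
    return await handler.chat_completion(**kwargs)

print(asyncio.run(ask(StubAIHandler(), "some-model", "sys", "user", img_path="https://example.com/x.png")))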
@@ -7,21 +7,29 @@ from jinja2 import Environment, StrictUndefined

from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from utils.pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files,
get_pr_diff,
retry_with_fallback_models)
from utils.pr_agent.algo.pr_processing import (
add_ai_metadata_to_diff_files,
get_pr_diff,
retry_with_fallback_models,
)
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.utils import (ModelType, PRReviewHeader,
convert_to_markdown_v2, github_action_output,
load_yaml, show_relevant_configurations)
from utils.pr_agent.algo.utils import (
ModelType,
PRReviewHeader,
convert_to_markdown_v2,
github_action_output,
load_yaml,
show_relevant_configurations,
)
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers import (get_git_provider_with_context)
from utils.pr_agent.git_providers.git_provider import (IncrementalPR,
get_main_pr_language)
from utils.pr_agent.git_providers import get_git_provider_with_context
from utils.pr_agent.git_providers.git_provider import (
IncrementalPR,
get_main_pr_language,
)
from utils.pr_agent.log import get_logger
from utils.pr_agent.servers.help import HelpMessage
from utils.pr_agent.tools.ticket_pr_compliance_check import (
extract_and_cache_pr_tickets)
from utils.pr_agent.tools.ticket_pr_compliance_check import extract_and_cache_pr_tickets


class PRReviewer:
@@ -29,8 +37,14 @@ class PRReviewer:
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
"""

def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
def __init__(
self,
pr_url: str,
is_answer: bool = False,
is_auto: bool = False,
args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
):
"""
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.

@@ -55,16 +69,23 @@ class PRReviewer:
self.is_auto = is_auto

if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
raise Exception(
f"Answer mode is not supported for {get_settings().config.git_provider} for now"
)
self.ai_handler = ai_handler()
self.ai_handler.main_pr_language = self.main_language
self.patches_diff = None
self.prediction = None
answer_str, question_str = self._get_user_answers()
self.pr_description, self.pr_description_files = (
self.git_provider.get_pr_description(split_changes_walkthrough=True))
if (self.pr_description_files and get_settings().get("config.is_auto_command", False) and
get_settings().get("config.enable_ai_metadata", False)):
(
self.pr_description,
self.pr_description_files,
) = self.git_provider.get_pr_description(split_changes_walkthrough=True)
if (
self.pr_description_files
and get_settings().get("config.is_auto_command", False)
and get_settings().get("config.enable_ai_metadata", False)
):
add_ai_metadata_to_diff_files(self.git_provider, self.pr_description_files)
get_logger().debug(f"AI metadata added to the this command")
else:
@@ -89,9 +110,11 @@ class PRReviewer:
"commit_messages_str": self.git_provider.get_commit_messages(),
"custom_labels": "",
"enable_custom_labels": get_settings().config.enable_custom_labels,
"is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
"is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
"related_tickets": get_settings().get('related_tickets', []),
'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False),
'duplicate_prompt_examples': get_settings().config.get(
'duplicate_prompt_examples', False
),
"date": datetime.datetime.now().strftime('%Y-%m-%d'),
}

@@ -99,7 +122,7 @@ class PRReviewer:
self.git_provider.pr,
self.vars,
get_settings().pr_review_prompt.system,
get_settings().pr_review_prompt.user
get_settings().pr_review_prompt.user,
)

def parse_incremental(self, args: List[str]):
@@ -117,7 +140,10 @@ class PRReviewer:
get_logger().info(f"PR has no files: {self.pr_url}, skipping review")
return None

if self.incremental.is_incremental and not self._can_run_incremental_review():
if (
self.incremental.is_incremental
and not self._can_run_incremental_review()
):
return None

# if isinstance(self.args, list) and self.args and self.args[0] == 'auto_approve':
@@ -126,27 +152,41 @@ class PRReviewer:
# return None

get_logger().info(f'Reviewing PR: {self.pr_url} ...')
relevant_configs = {'pr_reviewer': dict(get_settings().pr_reviewer),
'config': dict(get_settings().config)}
relevant_configs = {
'pr_reviewer': dict(get_settings().pr_reviewer),
'config': dict(get_settings().config),
}
get_logger().debug("Relevant configs", artifacts=relevant_configs)

# ticket extraction if exists
await extract_and_cache_pr_tickets(self.git_provider, self.vars)

if self.incremental.is_incremental and hasattr(self.git_provider, "unreviewed_files_set") and not self.git_provider.unreviewed_files_set:
get_logger().info(f"Incremental review is enabled for {self.pr_url} but there are no new files")
if (
self.incremental.is_incremental
and hasattr(self.git_provider, "unreviewed_files_set")
and not self.git_provider.unreviewed_files_set
):
get_logger().info(
f"Incremental review is enabled for {self.pr_url} but there are no new files"
)
previous_review_url = ""
if hasattr(self.git_provider, "previous_review"):
previous_review_url = self.git_provider.previous_review.html_url
if get_settings().config.publish_output:
self.git_provider.publish_comment(f"Incremental Review Skipped\n"
f"No files were changed since the [previous PR Review]({previous_review_url})")
self.git_provider.publish_comment(
f"Incremental Review Skipped\n"
f"No files were changed since the [previous PR Review]({previous_review_url})"
)
return None

if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False):
if get_settings().config.publish_output and not get_settings().config.get(
'is_auto_command', False
):
self.git_provider.publish_comment("准备评审中...", is_temporary=True)

await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
await retry_with_fallback_models(
self._prepare_prediction, model_type=ModelType.REGULAR
)
if not self.prediction:
self.git_provider.remove_initial_comment()
return None
@@ -156,12 +196,19 @@ class PRReviewer:

if get_settings().config.publish_output:
# publish the review
if get_settings().pr_reviewer.persistent_comment and not self.incremental.is_incremental:
final_update_message = get_settings().pr_reviewer.final_update_message
self.git_provider.publish_persistent_comment(pr_review,
initial_header=f"{PRReviewHeader.REGULAR.value} 🔍",
update_header=True,
final_update_message=final_update_message, )
if (
get_settings().pr_reviewer.persistent_comment
and not self.incremental.is_incremental
):
final_update_message = (
get_settings().pr_reviewer.final_update_message
)
self.git_provider.publish_persistent_comment(
pr_review,
initial_header=f"{PRReviewHeader.REGULAR.value} 🔍",
update_header=True,
final_update_message=final_update_message,
)
else:
self.git_provider.publish_comment(pr_review)

@@ -174,11 +221,13 @@ class PRReviewer:
get_logger().error(f"Failed to review PR: {e}")

async def _prepare_prediction(self, model: str) -> None:
self.patches_diff = get_pr_diff(self.git_provider,
self.token_handler,
model,
add_line_numbers_to_hunks=True,
disable_extra_lines=False,)
self.patches_diff = get_pr_diff(
self.git_provider,
self.token_handler,
model,
add_line_numbers_to_hunks=True,
disable_extra_lines=False,
)

if self.patches_diff:
get_logger().debug(f"PR diff", diff=self.patches_diff)
@@ -201,14 +250,18 @@ class PRReviewer:
variables["diff"] = self.patches_diff # update diff

environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_review_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables)
system_prompt = environment.from_string(
get_settings().pr_review_prompt.system
).render(variables)
user_prompt = environment.from_string(
get_settings().pr_review_prompt.user
).render(variables)

response, finish_reason = await self.ai_handler.chat_completion(
model=model,
temperature=get_settings().config.temperature,
system=system_prompt,
user=user_prompt
user=user_prompt,
)

return response
@@ -220,10 +273,20 @@ class PRReviewer:
"""
first_key = 'review'
last_key = 'security_concerns'
data = load_yaml(self.prediction.strip(),
keys_fix_yaml=["ticket_compliance_check", "estimated_effort_to_review_[1-5]:", "security_concerns:", "key_issues_to_review:",
"relevant_file:", "relevant_line:", "suggestion:"],
first_key=first_key, last_key=last_key)
data = load_yaml(
self.prediction.strip(),
keys_fix_yaml=[
"ticket_compliance_check",
"estimated_effort_to_review_[1-5]:",
"security_concerns:",
"key_issues_to_review:",
"relevant_file:",
"relevant_line:",
"suggestion:",
],
first_key=first_key,
last_key=last_key,
)
github_action_output(data, 'review')

# move data['review'] 'key_issues_to_review' key to the end of the dictionary
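load_yaml with keys_fix_yaml is the project's tolerant parser for model-generated YAML; its internals are not shown in this diff. A generic sketch of the likely idea — quoting the values of known keys before parsing — under that stated assumption:

import re
import yaml  # PyYAML

def load_yaml_tolerant(text, keys_fix=("security_concerns:", "suggestion:")):
    # Assumed repair strategy, a guess at the spirit of keys_fix_yaml rather
    # than the real implementation: quote the value after known keys so stray
    # colons in model output don't derail the parser.
    def quote_value(m):
        value = m.group(2).strip()
        if value.startswith(('"', "'", '|', '>')):
            return m.group(0)  # already quoted, or a block scalar
        return f'{m.group(1)} "{value}"'
    for key in keys_fix:
        text = re.sub(rf'^(\s*{re.escape(key)})\s*(.+)$', quote_value, text, flags=re.MULTILINE)
    return yaml.safe_load(text)

print(load_yaml_tolerant("review:\n  security_concerns: yes: SQL injection risk"))
# {'review': {'security_concerns': 'yes: SQL injection risk'}}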
@@ -234,24 +297,38 @@ class PRReviewer:
incremental_review_markdown_text = None
# Add incremental review section
if self.incremental.is_incremental:
last_commit_url = f"{self.git_provider.get_pr_url()}/commits/" \
f"{self.git_provider.incremental.first_new_commit_sha}"
last_commit_url = (
f"{self.git_provider.get_pr_url()}/commits/"
f"{self.git_provider.incremental.first_new_commit_sha}"
)
incremental_review_markdown_text = f"Starting from commit {last_commit_url}"

markdown_text = convert_to_markdown_v2(data, self.git_provider.is_supported("gfm_markdown"),
incremental_review_markdown_text,
git_provider=self.git_provider,
files=self.git_provider.get_diff_files())
markdown_text = convert_to_markdown_v2(
data,
self.git_provider.is_supported("gfm_markdown"),
incremental_review_markdown_text,
git_provider=self.git_provider,
files=self.git_provider.get_diff_files(),
)

# Add help text if gfm_markdown is supported
if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_reviewer.enable_help_text:
if (
self.git_provider.is_supported("gfm_markdown")
and get_settings().pr_reviewer.enable_help_text
):
markdown_text += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
markdown_text += HelpMessage.get_review_usage_guide()
markdown_text += "\n</details>\n"

# Output the relevant configurations if enabled
if get_settings().get('config', {}).get('output_relevant_configurations', False):
markdown_text += show_relevant_configurations(relevant_section='pr_reviewer')
if (
get_settings()
.get('config', {})
.get('output_relevant_configurations', False)
):
markdown_text += show_relevant_configurations(
relevant_section='pr_reviewer'
)

# Add custom labels from the review prediction (effort, security)
self.set_review_labels(data)
@@ -306,34 +383,50 @@ class PRReviewer:
if comment:
self.git_provider.remove_comment(comment)
except Exception as e:
get_logger().exception(f"Failed to remove previous review comment, error: {e}")
get_logger().exception(
f"Failed to remove previous review comment, error: {e}"
)

def _can_run_incremental_review(self) -> bool:
"""Checks if we can run incremental review according the various configurations and previous review"""
# checking if running is auto mode but there are no new commits
if self.is_auto and not self.incremental.first_new_commit_sha:
get_logger().info(f"Incremental review is enabled for {self.pr_url} but there are no new commits")
get_logger().info(
f"Incremental review is enabled for {self.pr_url} but there are no new commits"
)
return False

if not hasattr(self.git_provider, "get_incremental_commits"):
get_logger().info(f"Incremental review is not supported for {get_settings().config.git_provider}")
get_logger().info(
f"Incremental review is not supported for {get_settings().config.git_provider}"
)
return False
# checking if there are enough commits to start the review
num_new_commits = len(self.incremental.commits_range)
num_commits_threshold = get_settings().pr_reviewer.minimal_commits_for_incremental_review
num_commits_threshold = (
get_settings().pr_reviewer.minimal_commits_for_incremental_review
)
not_enough_commits = num_new_commits < num_commits_threshold
# checking if the commits are not too recent to start the review
recent_commits_threshold = datetime.datetime.now() - datetime.timedelta(
minutes=get_settings().pr_reviewer.minimal_minutes_for_incremental_review
)
last_seen_commit_date = (
self.incremental.last_seen_commit.commit.author.date if self.incremental.last_seen_commit else None
self.incremental.last_seen_commit.commit.author.date
if self.incremental.last_seen_commit
else None
)
all_commits_too_recent = (
last_seen_commit_date > recent_commits_threshold if self.incremental.last_seen_commit else False
last_seen_commit_date > recent_commits_threshold
if self.incremental.last_seen_commit
else False
)
# check all the thresholds or just one to start the review
condition = any if get_settings().pr_reviewer.require_all_thresholds_for_incremental_review else all
condition = (
any
if get_settings().pr_reviewer.require_all_thresholds_for_incremental_review
else all
)
if condition((not_enough_commits, all_commits_too_recent)):
get_logger().info(
f"Incremental review is enabled for {self.pr_url} but didn't pass the threshold check to run:"
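condition = any versus all decides whether one tripped gate or every tripped gate blocks the incremental review. The flag's effect in isolation:

not_enough_commits = True
all_commits_too_recent = False

require_all_thresholds = False  # the config flag from the hunk above
condition = any if require_all_thresholds else all

# With all(): both gates must trip to skip; with any(): a single gate suffices.
if condition((not_enough_commits, all_commits_too_recent)):
    print("skipping incremental review")
else:
    print("running incremental review")  # printed here: all((True, False)) is False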
@@ -348,31 +441,55 @@ class PRReviewer:
return

if not get_settings().pr_reviewer.require_estimate_effort_to_review:
get_settings().pr_reviewer.enable_review_labels_effort = False # we did not generate this output
get_settings().pr_reviewer.enable_review_labels_effort = (
False # we did not generate this output
)
if not get_settings().pr_reviewer.require_security_review:
get_settings().pr_reviewer.enable_review_labels_security = False # we did not generate this output
get_settings().pr_reviewer.enable_review_labels_security = (
False # we did not generate this output
)

if (get_settings().pr_reviewer.enable_review_labels_security or
get_settings().pr_reviewer.enable_review_labels_effort):
if (
get_settings().pr_reviewer.enable_review_labels_security
or get_settings().pr_reviewer.enable_review_labels_effort
):
try:
review_labels = []
if get_settings().pr_reviewer.enable_review_labels_effort:
estimated_effort = data['review']['estimated_effort_to_review_[1-5]']
estimated_effort = data['review'][
'estimated_effort_to_review_[1-5]'
]
estimated_effort_number = 0
if isinstance(estimated_effort, str):
try:
estimated_effort_number = int(estimated_effort.split(',')[0])
estimated_effort_number = int(
estimated_effort.split(',')[0]
)
except ValueError:
get_logger().warning(f"Invalid estimated_effort value: {estimated_effort}")
get_logger().warning(
f"Invalid estimated_effort value: {estimated_effort}"
)
elif isinstance(estimated_effort, int):
estimated_effort_number = estimated_effort
else:
get_logger().warning(f"Unexpected type for estimated_effort: {type(estimated_effort)}")
get_logger().warning(
f"Unexpected type for estimated_effort: {type(estimated_effort)}"
)
if 1 <= estimated_effort_number <= 5: # 1, because ...
review_labels.append(f'Review effort {estimated_effort_number}/5')
if get_settings().pr_reviewer.enable_review_labels_security and get_settings().pr_reviewer.require_security_review:
security_concerns = data['review']['security_concerns'] # yes, because ...
security_concerns_bool = 'yes' in security_concerns.lower() or 'true' in security_concerns.lower()
review_labels.append(
f'Review effort {estimated_effort_number}/5'
)
if (
get_settings().pr_reviewer.enable_review_labels_security
and get_settings().pr_reviewer.require_security_review
):
security_concerns = data['review'][
'security_concerns'
] # yes, because ...
security_concerns_bool = (
'yes' in security_concerns.lower()
or 'true' in security_concerns.lower()
)
if security_concerns_bool:
review_labels.append('Possible security concern')
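The labeling code has to tolerate estimated_effort arriving either as an int or as a string like "3, because ...". Condensed into one helper:

def effort_label(estimated_effort):
    number = 0
    if isinstance(estimated_effort, str):
        try:
            number = int(estimated_effort.split(',')[0])
        except ValueError:
            return None
    elif isinstance(estimated_effort, int):
        number = estimated_effort
    if 1 <= number <= 5:
        return f'Review effort {number}/5'
    return None

print(effort_label("3, because the diff is small"))  # Review effort 3/5
print(effort_label(7))                               # None (out of the 1-5 range)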
@@ -381,17 +498,26 @@ class PRReviewer:
current_labels = []
get_logger().debug(f"Current labels:\n{current_labels}")
if current_labels:
current_labels_filtered = [label for label in current_labels if
not label.lower().startswith('review effort') and not label.lower().startswith(
'possible security concern')]
current_labels_filtered = [
label
for label in current_labels
if not label.lower().startswith('review effort')
and not label.lower().startswith('possible security concern')
]
else:
current_labels_filtered = []
new_labels = review_labels + current_labels_filtered
if (current_labels or review_labels) and sorted(new_labels) != sorted(current_labels):
get_logger().info(f"Setting review labels:\n{review_labels + current_labels_filtered}")
if (current_labels or review_labels) and sorted(new_labels) != sorted(
current_labels
):
get_logger().info(
f"Setting review labels:\n{review_labels + current_labels_filtered}"
)
self.git_provider.publish_labels(new_labels)
else:
get_logger().info(f"Review labels are already set:\n{review_labels + current_labels_filtered}")
get_logger().info(
f"Review labels are already set:\n{review_labels + current_labels_filtered}"
)
except Exception as e:
get_logger().error(f"Failed to set review labels, error: {e}")

@@ -406,5 +532,7 @@ class PRReviewer:
self.git_provider.publish_comment("自动批准 PR")
else:
get_logger().info("Auto-approval option is disabled")
self.git_provider.publish_comment("PR-Agent 的自动批准选项已禁用. "
"你可以通过此设置打开 [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)")
self.git_provider.publish_comment(
"PR-Agent 的自动批准选项已禁用. "
"你可以通过此设置打开 [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)"
)
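publish_labels is only called when the sorted label sets actually differ, avoiding a redundant API write; old effort/security labels are filtered out first so they get replaced rather than duplicated. The comparison on its own:

review_labels = ['Review effort 3/5']
current_labels = ['bug', 'Review effort 2/5']

current_labels_filtered = [
    label for label in current_labels
    if not label.lower().startswith('review effort')
    and not label.lower().startswith('possible security concern')
]
new_labels = review_labels + current_labels_filtered
if (current_labels or review_labels) and sorted(new_labels) != sorted(current_labels):
    print("publish:", new_labels)  # publish: ['Review effort 3/5', 'bug']
else:
    print("labels already up to date")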
@@ -24,12 +24,16 @@ class PRSimilarIssue:
self.max_issues_to_scan = get_settings().pr_similar_issue.max_issues_to_scan
self.issue_url = issue_url
self.git_provider = get_git_provider()()
repo_name, issue_number = self.git_provider._parse_issue_url(issue_url.split('=')[-1])
repo_name, issue_number = self.git_provider._parse_issue_url(
issue_url.split('=')[-1]
)
self.git_provider.repo = repo_name
self.git_provider.repo_obj = self.git_provider.github_client.get_repo(repo_name)
self.token_handler = TokenHandler()
repo_obj = self.git_provider.repo_obj
repo_name_for_index = self.repo_name_for_index = repo_obj.full_name.lower().replace('/', '-').replace('_/', '-')
repo_name_for_index = self.repo_name_for_index = (
repo_obj.full_name.lower().replace('/', '-').replace('_/', '-')
)
index_name = self.index_name = "codium-ai-pr-agent-issues"

if get_settings().pr_similar_issue.vectordb == "pinecone":
@@ -38,17 +42,30 @@ class PRSimilarIssue:
import pinecone
from pinecone_datasets import Dataset, DatasetMetadata
except:
raise Exception("Please install 'pinecone' and 'pinecone_datasets' to use pinecone as vectordb")
raise Exception(
"Please install 'pinecone' and 'pinecone_datasets' to use pinecone as vectordb"
)
# assuming pinecone api key and environment are set in secrets file
try:
api_key = get_settings().pinecone.api_key
environment = get_settings().pinecone.environment
except Exception:
if not self.cli_mode:
repo_name, original_issue_number = self.git_provider._parse_issue_url(self.issue_url.split('=')[-1])
issue_main = self.git_provider.repo_obj.get_issue(original_issue_number)
issue_main.create_comment("Please set pinecone api key and environment in secrets file")
raise Exception("Please set pinecone api key and environment in secrets file")
(
repo_name,
original_issue_number,
) = self.git_provider._parse_issue_url(
self.issue_url.split('=')[-1]
)
issue_main = self.git_provider.repo_obj.get_issue(
original_issue_number
)
issue_main.create_comment(
"Please set pinecone api key and environment in secrets file"
)
raise Exception(
"Please set pinecone api key and environment in secrets file"
)

# check if index exists, and if repo is already indexed
run_from_scratch = False
@@ -69,7 +86,9 @@ class PRSimilarIssue:
upsert = True
else:
pinecone_index = pinecone.Index(index_name=index_name)
res = pinecone_index.fetch([f"example_issue_{repo_name_for_index}"]).to_dict()
res = pinecone_index.fetch(
[f"example_issue_{repo_name_for_index}"]
).to_dict()
if res["vectors"]:
upsert = False

@@ -79,7 +98,9 @@ class PRSimilarIssue:
get_logger().info('Getting issues...')
issues = list(repo_obj.get_issues(state='all'))
get_logger().info('Done')
self._update_index_with_issues(issues, repo_name_for_index, upsert=upsert)
self._update_index_with_issues(
issues, repo_name_for_index, upsert=upsert
)
else: # update index if needed
pinecone_index = pinecone.Index(index_name=index_name)
issues_to_update = []
@@ -105,7 +126,9 @@ class PRSimilarIssue:

if issues_to_update:
get_logger().info(f'Updating index with {counter} new issues...')
self._update_index_with_issues(issues_to_update, repo_name_for_index, upsert=True)
self._update_index_with_issues(
issues_to_update, repo_name_for_index, upsert=True
)
else:
get_logger().info('No new issues to update')

@@ -133,7 +156,12 @@ class PRSimilarIssue:
ingest = True
else:
self.table = self.db[index_name]
res = self.table.search().limit(len(self.table)).where(f"id='example_issue_{repo_name_for_index}'").to_list()
res = (
self.table.search()
.limit(len(self.table))
.where(f"id='example_issue_{repo_name_for_index}'")
.to_list()
)
get_logger().info("result: ", res)
if res[0].get("vector"):
ingest = False
@@ -145,7 +173,9 @@ class PRSimilarIssue:
issues = list(repo_obj.get_issues(state='all'))
get_logger().info('Done')

self._update_table_with_issues(issues, repo_name_for_index, ingest=ingest)
self._update_table_with_issues(
issues, repo_name_for_index, ingest=ingest
)
else: # update table if needed
issues_to_update = []
issues_paginated_list = repo_obj.get_issues(state='all')
@@ -156,7 +186,12 @@ class PRSimilarIssue:
issue_str, comments, number = self._process_issue(issue)
issue_key = f"issue_{number}"
issue_id = issue_key + "." + "issue"
res = self.table.search().limit(len(self.table)).where(f"id='{issue_id}'").to_list()
res = (
self.table.search()
.limit(len(self.table))
.where(f"id='{issue_id}'")
.to_list()
)
is_new_issue = True
for r in res:
if r['metadata']['repo'] == repo_name_for_index:
@@ -170,14 +205,17 @@ class PRSimilarIssue:

if issues_to_update:
get_logger().info(f'Updating index with {counter} new issues...')
self._update_table_with_issues(issues_to_update, repo_name_for_index, ingest=True)
self._update_table_with_issues(
issues_to_update, repo_name_for_index, ingest=True
)
else:
get_logger().info('No new issues to update')


async def run(self):
get_logger().info('Getting issue...')
repo_name, original_issue_number = self.git_provider._parse_issue_url(self.issue_url.split('=')[-1])
repo_name, original_issue_number = self.git_provider._parse_issue_url(
self.issue_url.split('=')[-1]
)
issue_main = self.git_provider.repo_obj.get_issue(original_issue_number)
issue_str, comments, number = self._process_issue(issue_main)
openai.api_key = get_settings().openai.key
@@ -193,10 +231,12 @@ class PRSimilarIssue:

if get_settings().pr_similar_issue.vectordb == "pinecone":
pinecone_index = pinecone.Index(index_name=self.index_name)
res = pinecone_index.query(embeds[0],
top_k=5,
filter={"repo": self.repo_name_for_index},
include_metadata=True).to_dict()
res = pinecone_index.query(
embeds[0],
top_k=5,
filter={"repo": self.repo_name_for_index},
include_metadata=True,
).to_dict()

for r in res['matches']:
# skip example issue
@@ -214,14 +254,20 @@ class PRSimilarIssue:
if issue_number not in relevant_issues_number_list:
relevant_issues_number_list.append(issue_number)
if 'comment' in r["id"]:
relevant_comment_number_list.append(int(r["id"].split('.')[1].split('_')[-1]))
relevant_comment_number_list.append(
int(r["id"].split('.')[1].split('_')[-1])
)
else:
relevant_comment_number_list.append(-1)
score_list.append(str("{:.2f}".format(r['score'])))
get_logger().info('Done')

elif get_settings().pr_similar_issue.vectordb == "lancedb":
res = self.table.search(embeds[0]).where(f"metadata.repo='{self.repo_name_for_index}'", prefilter=True).to_list()
res = (
self.table.search(embeds[0])
.where(f"metadata.repo='{self.repo_name_for_index}'", prefilter=True)
.to_list()
)

for r in res:
# skip example issue
@@ -240,10 +286,12 @@ class PRSimilarIssue:
relevant_issues_number_list.append(issue_number)

if 'comment' in r["id"]:
relevant_comment_number_list.append(int(r["id"].split('.')[1].split('_')[-1]))
relevant_comment_number_list.append(
int(r["id"].split('.')[1].split('_')[-1])
)
else:
relevant_comment_number_list.append(-1)
score_list.append(str("{:.2f}".format(1-r['_distance'])))
score_list.append(str("{:.2f}".format(1 - r['_distance'])))
get_logger().info('Done')

get_logger().info('Publishing response...')
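Pinecone returns a similarity score directly, while LanceDB returns a distance — hence the 1 - r['_distance'] conversion above (which assumes a normalized, cosine-style distance). Side by side:

pinecone_match = {"score": 0.87}
lancedb_row = {"_distance": 0.13}

print("{:.2f}".format(pinecone_match["score"]))       # 0.87
print("{:.2f}".format(1 - lancedb_row["_distance"]))  # 0.87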
@@ -254,8 +302,12 @@ class PRSimilarIssue:
title = issue.title
url = issue.html_url
if relevant_comment_number_list[i] != -1:
url = list(issue.get_comments())[relevant_comment_number_list[i]].html_url
similar_issues_str += f"{i + 1}. **[{title}]({url})** (score={score_list[i]})\n\n"
url = list(issue.get_comments())[
relevant_comment_number_list[i]
].html_url
similar_issues_str += (
f"{i + 1}. **[{title}]({url})** (score={score_list[i]})\n\n"
)
if get_settings().config.publish_output:
response = issue_main.create_comment(similar_issues_str)
get_logger().info(similar_issues_str)
@@ -278,7 +330,7 @@ class PRSimilarIssue:
example_issue_record = Record(
id=f"example_issue_{repo_name_for_index}",
text="example_issue",
metadata=Metadata(repo=repo_name_for_index)
metadata=Metadata(repo=repo_name_for_index),
)
corpus.append(example_issue_record)

@@ -298,15 +350,20 @@ class PRSimilarIssue:
issue_key = f"issue_{number}"
username = issue.user.login
created_at = str(issue.created_at)
if len(issue_str) < 8000 or \
self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL): # fast reject first
if len(issue_str) < 8000 or self.token_handler.count_tokens(
issue_str
) < get_max_tokens(
MODEL
): # fast reject first
issue_record = Record(
id=issue_key + "." + "issue",
text=issue_str,
metadata=Metadata(repo=repo_name_for_index,
username=username,
created_at=created_at,
level=IssueLevel.ISSUE)
metadata=Metadata(
repo=repo_name_for_index,
username=username,
created_at=created_at,
level=IssueLevel.ISSUE,
),
)
corpus.append(issue_record)
if comments:
@@ -316,15 +373,20 @@ class PRSimilarIssue:
if num_words_comment < 10 or not isinstance(comment_body, str):
continue

if len(comment_body) < 8000 or \
self.token_handler.count_tokens(comment_body) < MAX_TOKENS[MODEL]:
if (
len(comment_body) < 8000
or self.token_handler.count_tokens(comment_body)
< MAX_TOKENS[MODEL]
):
comment_record = Record(
id=issue_key + ".comment_" + str(j + 1),
text=comment_body,
metadata=Metadata(repo=repo_name_for_index,
username=username, # use issue username for all comments
created_at=created_at,
level=IssueLevel.COMMENT)
metadata=Metadata(
repo=repo_name_for_index,
username=username, # use issue username for all comments
created_at=created_at,
level=IssueLevel.COMMENT,
),
)
corpus.append(comment_record)
df = pd.DataFrame(corpus.dict()["documents"])
@@ -355,7 +417,9 @@ class PRSimilarIssue:
environment = get_settings().pinecone.environment
if not upsert:
get_logger().info('Creating index from scratch...')
ds.to_pinecone_index(self.index_name, api_key=api_key, environment=environment)
ds.to_pinecone_index(
self.index_name, api_key=api_key, environment=environment
)
time.sleep(15) # wait for pinecone to finalize indexing before querying
else:
get_logger().info('Upserting index...')
@@ -374,7 +438,7 @@ class PRSimilarIssue:
example_issue_record = Record(
id=f"example_issue_{repo_name_for_index}",
text="example_issue",
metadata=Metadata(repo=repo_name_for_index)
metadata=Metadata(repo=repo_name_for_index),
)
corpus.append(example_issue_record)

@@ -394,15 +458,20 @@ class PRSimilarIssue:
issue_key = f"issue_{number}"
username = issue.user.login
created_at = str(issue.created_at)
if len(issue_str) < 8000 or \
self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL): # fast reject first
if len(issue_str) < 8000 or self.token_handler.count_tokens(
issue_str
) < get_max_tokens(
MODEL
): # fast reject first
issue_record = Record(
id=issue_key + "." + "issue",
text=issue_str,
metadata=Metadata(repo=repo_name_for_index,
username=username,
created_at=created_at,
level=IssueLevel.ISSUE)
metadata=Metadata(
repo=repo_name_for_index,
username=username,
created_at=created_at,
level=IssueLevel.ISSUE,
),
)
corpus.append(issue_record)
if comments:
@@ -412,15 +481,20 @@ class PRSimilarIssue:
if num_words_comment < 10 or not isinstance(comment_body, str):
continue

if len(comment_body) < 8000 or \
self.token_handler.count_tokens(comment_body) < MAX_TOKENS[MODEL]:
if (
len(comment_body) < 8000
or self.token_handler.count_tokens(comment_body)
< MAX_TOKENS[MODEL]
):
comment_record = Record(
id=issue_key + ".comment_" + str(j + 1),
text=comment_body,
metadata=Metadata(repo=repo_name_for_index,
username=username, # use issue username for all comments
created_at=created_at,
level=IssueLevel.COMMENT)
metadata=Metadata(
repo=repo_name_for_index,
username=username, # use issue username for all comments
created_at=created_at,
level=IssueLevel.COMMENT,
),
)
corpus.append(comment_record)
df = pd.DataFrame(corpus.dict()["documents"])
@@ -446,7 +520,9 @@ class PRSimilarIssue:

if not ingest:
get_logger().info('Creating table from scratch...')
self.table = self.db.create_table(self.index_name, data=df, mode="overwrite")
self.table = self.db.create_table(
self.index_name, data=df, mode="overwrite"
)
time.sleep(15)
else:
get_logger().info('Ingesting in Table...')
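The repeated len(...) < 8000 or count_tokens(...) < limit guard is a fast reject: the cheap length check short-circuits, so the expensive tokenizer only runs on long texts. A sketch with a stand-in tokenizer:

MAX_MODEL_TOKENS = 8192  # illustrative limit; the real code looks this up per model

def count_tokens(text: str) -> int:
    # Stand-in for the real TokenHandler; imagine this call being expensive.
    return len(text.split())

def fits(text: str) -> bool:
    # Short texts are accepted by the cheap length check alone; only long
    # texts pay for tokenization.
    return len(text) < 8000 or count_tokens(text) < MAX_MODEL_TOKENS

print(fits("short issue body"))  # True, without ever tokenizing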
@ -20,13 +20,20 @@ CHANGELOG_LINES = 50
|
||||
|
||||
|
||||
class PRUpdateChangelog:
|
||||
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pr_url: str,
|
||||
cli_mode=False,
|
||||
args=None,
|
||||
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
|
||||
):
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
self.main_language = get_main_pr_language(
|
||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||
)
|
||||
self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
|
||||
self.commit_changelog = (
|
||||
get_settings().pr_update_changelog.push_changelog_changes
|
||||
)
|
||||
self._get_changelog_file() # self.changelog_file_str
|
||||
|
||||
self.ai_handler = ai_handler()
|
||||
@ -47,15 +54,19 @@ class PRUpdateChangelog:
|
||||
"extra_instructions": get_settings().pr_update_changelog.extra_instructions,
|
||||
"commit_messages_str": self.git_provider.get_commit_messages(),
|
||||
}
|
||||
self.token_handler = TokenHandler(self.git_provider.pr,
|
||||
self.vars,
|
||||
get_settings().pr_update_changelog_prompt.system,
|
||||
get_settings().pr_update_changelog_prompt.user)
|
||||
self.token_handler = TokenHandler(
|
||||
self.git_provider.pr,
|
||||
self.vars,
|
||||
get_settings().pr_update_changelog_prompt.system,
|
||||
get_settings().pr_update_changelog_prompt.user,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
get_logger().info('Updating the changelog...')
|
||||
relevant_configs = {'pr_update_changelog': dict(get_settings().pr_update_changelog),
|
||||
'config': dict(get_settings().config)}
|
||||
relevant_configs = {
|
||||
'pr_update_changelog': dict(get_settings().pr_update_changelog),
|
||||
'config': dict(get_settings().config),
|
||||
}
|
||||
get_logger().debug("Relevant configs", artifacts=relevant_configs)
|
||||
|
||||
# currently only GitHub is supported for pushing changelog changes
|
||||
@ -74,13 +85,21 @@ class PRUpdateChangelog:
|
||||
if get_settings().config.publish_output:
|
||||
self.git_provider.publish_comment("准备变更日志更新中...", is_temporary=True)
|
||||
|
||||
await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK)
|
||||
await retry_with_fallback_models(
|
||||
self._prepare_prediction, model_type=ModelType.WEAK
|
||||
)
|
||||
|
||||
new_file_content, answer = self._prepare_changelog_update()
|
||||
|
||||
# Output the relevant configurations if enabled
|
||||
if get_settings().get('config', {}).get('output_relevant_configurations', False):
|
||||
answer += show_relevant_configurations(relevant_section='pr_update_changelog')
|
||||
if (
|
||||
get_settings()
|
||||
.get('config', {})
|
||||
.get('output_relevant_configurations', False)
|
||||
):
|
||||
answer += show_relevant_configurations(
|
||||
relevant_section='pr_update_changelog'
|
||||
)
|
||||
|
||||
get_logger().debug(f"PR output", artifact=answer)
|
||||
|
||||
@ -89,7 +108,9 @@ class PRUpdateChangelog:
            if self.commit_changelog:
                self._push_changelog_update(new_file_content, answer)
            else:
                self.git_provider.publish_comment(f"**Changelog updates:** 🔄\n\n{answer}")
                self.git_provider.publish_comment(
                    f"**Changelog updates:** 🔄\n\n{answer}"
                )

    async def _prepare_prediction(self, model: str):
        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
@ -106,10 +127,18 @@ class PRUpdateChangelog:
        if get_settings().pr_update_changelog.add_pr_link:
            variables["pr_link"] = self.git_provider.get_pr_url()
        environment = Environment(undefined=StrictUndefined)
        system_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.system).render(variables)
        user_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.user).render(variables)
        system_prompt = environment.from_string(
            get_settings().pr_update_changelog_prompt.system
        ).render(variables)
        user_prompt = environment.from_string(
            get_settings().pr_update_changelog_prompt.user
        ).render(variables)
        response, finish_reason = await self.ai_handler.chat_completion(
            model=model, system=system_prompt, user=user_prompt, temperature=get_settings().config.temperature)
            model=model,
            system=system_prompt,
            user=user_prompt,
            temperature=get_settings().config.temperature,
        )

        # post-process the response
        response = response.strip()
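The rendering pattern in this hunk pairs Jinja2's Environment with StrictUndefined, so a missing template variable fails loudly instead of rendering as an empty string. A self-contained sketch (the template text and variables are made up):

    from jinja2 import Environment, StrictUndefined

    environment = Environment(undefined=StrictUndefined)
    template = environment.from_string("Update changelog for {{ pr_title }}")
    print(template.render({"pr_title": "Add config.ini support"}))
    # Rendering with a variable missing raises jinja2.exceptions.UndefinedError
    # rather than silently producing empty output.
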
@ -134,8 +163,10 @@ class PRUpdateChangelog:
            new_file_content = answer

        if not self.commit_changelog:
            answer += "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" \
                      "\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n"
            answer += (
                "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:"
                "\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n"
            )

        return new_file_content, answer

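The change above swaps a backslash line continuation for parenthesized implicit string concatenation, which black prefers. A quick illustration of why the two forms are equivalent:

    # Adjacent string literals inside parentheses are joined at compile time.
    answer = (
        "line one,"
        " line two"
    )
    assert answer == "line one, line two"
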
@ -163,8 +194,7 @@ class PRUpdateChangelog:
            self.git_provider.publish_comment(f"**Changelog updates: 🔄**\n\n{answer}")

    def _get_default_changelog(self):
        example_changelog = \
            """
        example_changelog = """
Example:
## <current_date>


@ -7,14 +7,15 @@ from utils.pr_agent.log import get_logger

# Compile the regex pattern once, outside the function
GITHUB_TICKET_PATTERN = re.compile(
    r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)'
    r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)'
)


def find_jira_tickets(text):
    # Regular expression patterns for JIRA tickets
    patterns = [
        r'\b[A-Z]{2,10}-\d{1,7}\b',  # Standard JIRA ticket format (e.g., PROJ-123)
        r'(?:https?://[^\s/]+/browse/)?([A-Z]{2,10}-\d{1,7})\b' # JIRA URL or just the ticket
        r'(?:https?://[^\s/]+/browse/)?([A-Z]{2,10}-\d{1,7})\b',  # JIRA URL or just the ticket
    ]

    tickets = set()
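For orientation, the compiled pattern above has three alternatives, which findall returns as a 6-tuple: the full issue URL, the owner/repo#N shorthand (with owner, repo, and number as separate groups), and a bare #N reference. A hypothetical demo (the pattern is from this diff; the sample text is made up):

    import re

    GITHUB_TICKET_PATTERN = re.compile(
        r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)'
    )

    text = "Fixes https://github.com/org/repo/issues/42, see org/repo#7 and #123"
    for match in GITHUB_TICKET_PATTERN.findall(text):
        # match is (url, shorthand, owner, repo, issue_number, bare_ref)
        print(match)
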
@ -32,7 +33,9 @@ def find_jira_tickets(text):
    return list(tickets)


def extract_ticket_links_from_pr_description(pr_description, repo_path, base_url_html='https://github.com'):
def extract_ticket_links_from_pr_description(
    pr_description, repo_path, base_url_html='https://github.com'
):
    """
    Extract all ticket links from PR description
    """
@ -46,19 +49,27 @@ def extract_ticket_links_from_pr_description(pr_description, repo_path, base_url
                github_tickets.add(match[0])
            elif match[1]:  # Shorthand notation match: owner/repo#issue_number
                owner, repo, issue_number = match[2], match[3], match[4]
                github_tickets.add(f'{base_url_html.strip("/")}/{owner}/{repo}/issues/{issue_number}')
                github_tickets.add(
                    f'{base_url_html.strip("/")}/{owner}/{repo}/issues/{issue_number}'
                )
            else:  # #123 format
                issue_number = match[5][1:]  # remove #
                if issue_number.isdigit() and len(issue_number) < 5 and repo_path:
                    github_tickets.add(f'{base_url_html.strip("/")}/{repo_path}/issues/{issue_number}')
                    github_tickets.add(
                        f'{base_url_html.strip("/")}/{repo_path}/issues/{issue_number}'
                    )

        if len(github_tickets) > 3:
            get_logger().info(f"Too many tickets found in PR description: {len(github_tickets)}")
            get_logger().info(
                f"Too many tickets found in PR description: {len(github_tickets)}"
            )
            # Limit the number of tickets to 3
            github_tickets = set(list(github_tickets)[:3])
    except Exception as e:
        get_logger().error(f"Error extracting tickets error= {e}",
                           artifact={"traceback": traceback.format_exc()})
        get_logger().error(
            f"Error extracting tickets error= {e}",
            artifact={"traceback": traceback.format_exc()},
        )

    return list(github_tickets)

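As a usage sketch, given the shorthand and bare-number handling above, the function expands both into full issue URLs (the inputs below are made up; output order may vary because the result is built from a set):

    description = "Closes #12 and tracks acme/tools#34"
    links = extract_ticket_links_from_pr_description(
        description, repo_path="acme/webapp", base_url_html="https://github.com"
    )
    # e.g. ['https://github.com/acme/webapp/issues/12',
    #       'https://github.com/acme/tools/issues/34']
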
@ -68,19 +79,26 @@ async def extract_tickets(git_provider):
    try:
        if isinstance(git_provider, GithubProvider):
            user_description = git_provider.get_user_description()
            tickets = extract_ticket_links_from_pr_description(user_description, git_provider.repo, git_provider.base_url_html)
            tickets = extract_ticket_links_from_pr_description(
                user_description, git_provider.repo, git_provider.base_url_html
            )
            tickets_content = []

            if tickets:

                for ticket in tickets:
                    repo_name, original_issue_number = git_provider._parse_issue_url(ticket)
                    repo_name, original_issue_number = git_provider._parse_issue_url(
                        ticket
                    )

                    try:
                        issue_main = git_provider.repo_obj.get_issue(original_issue_number)
                        issue_main = git_provider.repo_obj.get_issue(
                            original_issue_number
                        )
                    except Exception as e:
                        get_logger().error(f"Error getting main issue: {e}",
                                           artifact={"traceback": traceback.format_exc()})
                        get_logger().error(
                            f"Error getting main issue: {e}",
                            artifact={"traceback": traceback.format_exc()},
                        )
                        continue

                    issue_body_str = issue_main.body or ""
@ -93,47 +111,66 @@ async def extract_tickets(git_provider):
                            sub_issues = git_provider.fetch_sub_issues(ticket)
                            for sub_issue_url in sub_issues:
                                try:
                                    sub_repo, sub_issue_number = git_provider._parse_issue_url(sub_issue_url)
                                    sub_issue = git_provider.repo_obj.get_issue(sub_issue_number)
                                    (
                                        sub_repo,
                                        sub_issue_number,
                                    ) = git_provider._parse_issue_url(sub_issue_url)
                                    sub_issue = git_provider.repo_obj.get_issue(
                                        sub_issue_number
                                    )

                                    sub_body = sub_issue.body or ""
                                    if len(sub_body) > MAX_TICKET_CHARACTERS:
                                        sub_body = sub_body[:MAX_TICKET_CHARACTERS] + "..."

                                    sub_issues_content.append({
                                        'ticket_url': sub_issue_url,
                                        'title': sub_issue.title,
                                        'body': sub_body
                                    })
                                    sub_issues_content.append(
                                        {
                                            'ticket_url': sub_issue_url,
                                            'title': sub_issue.title,
                                            'body': sub_body,
                                        }
                                    )
                                except Exception as e:
                                    get_logger().warning(f"Failed to fetch sub-issue content for {sub_issue_url}: {e}")
                                    get_logger().warning(
                                        f"Failed to fetch sub-issue content for {sub_issue_url}: {e}"
                                    )

                        except Exception as e:
                            get_logger().warning(f"Failed to fetch sub-issues for {ticket}: {e}")
                            get_logger().warning(
                                f"Failed to fetch sub-issues for {ticket}: {e}"
                            )

                    # Extract labels
                    labels = []
                    try:
                        for label in issue_main.labels:
                            labels.append(label.name if hasattr(label, 'name') else label)
                            labels.append(
                                label.name if hasattr(label, 'name') else label
                            )
                    except Exception as e:
                        get_logger().error(f"Error extracting labels error= {e}",
                                           artifact={"traceback": traceback.format_exc()})
                        get_logger().error(
                            f"Error extracting labels error= {e}",
                            artifact={"traceback": traceback.format_exc()},
                        )

                    tickets_content.append({
                        'ticket_id': issue_main.number,
                        'ticket_url': ticket,
                        'title': issue_main.title,
                        'body': issue_body_str,
                        'labels': ", ".join(labels),
                        'sub_issues': sub_issues_content  # Store sub-issues content
                    })
                    tickets_content.append(
                        {
                            'ticket_id': issue_main.number,
                            'ticket_url': ticket,
                            'title': issue_main.title,
                            'body': issue_body_str,
                            'labels': ", ".join(labels),
                            'sub_issues': sub_issues_content,  # Store sub-issues content
                        }
                    )

            return tickets_content

    except Exception as e:
        get_logger().error(f"Error extracting tickets error= {e}",
                           artifact={"traceback": traceback.format_exc()})
        get_logger().error(
            f"Error extracting tickets error= {e}",
            artifact={"traceback": traceback.format_exc()},
        )


async def extract_and_cache_pr_tickets(git_provider, vars):
@ -154,8 +191,10 @@ async def extract_and_cache_pr_tickets(git_provider, vars):

        related_tickets.append(ticket)

    get_logger().info("Extracted tickets and sub-issues from PR description",
                      artifact={"tickets": related_tickets})
    get_logger().info(
        "Extracted tickets and sub-issues from PR description",
        artifact={"tickets": related_tickets},
    )

    vars['related_tickets'] = related_tickets
    get_settings().set('related_tickets', related_tickets)

13
config.ini
Normal file
13
config.ini
Normal file
@ -0,0 +1,13 @@
[BASE]
; Enable debug mode: 0 or 1
DEBUG = 0

[DATABASE]
; Defaults to sqlite; replace with pg in production
DEFAULT = pg
; Postgres configuration
DB_NAME = pr_manager
DB_USER = admin
DB_PASSWORD = admin123456
DB_HOST = 110.40.30.95
DB_PORT = 5432
@ -12,11 +12,18 @@ https://docs.djangoproject.com/en/5.1/ref/settings/

import os
import sys
import configparser
from pathlib import Path

# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent

CONFIG_NAME = BASE_DIR / "config.ini"

# Load the config file; in development, config.local.ini can be loaded instead
_config = configparser.ConfigParser()
_config.read(CONFIG_NAME, encoding="utf-8")

sys.path.insert(0, os.path.join(BASE_DIR, "apps"))
sys.path.insert(1, os.path.join(BASE_DIR, "apps/utils"))

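The comment above says development may load config.local.ini (which .gitignore now excludes from version control). configparser supports this kind of layering directly: read() accepts a list of files, silently skips missing ones, and lets later files override earlier values. A possible sketch of that override pattern (not what this diff implements, just one way to honor the comment):

    import configparser
    from pathlib import Path

    BASE_DIR = Path(__file__).resolve().parent.parent
    _config = configparser.ConfigParser()
    # config.local.ini is optional; any values it defines override config.ini.
    _config.read(
        [BASE_DIR / "config.ini", BASE_DIR / "config.local.ini"],
        encoding="utf-8",
    )
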
@ -27,7 +34,7 @@ sys.path.insert(1, os.path.join(BASE_DIR, "apps/utils"))
SECRET_KEY = "django-insecure-$r6lfcq8rev&&=chw259o$0o7t-!!%clc2ahs3xg$^z+gkms76"

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = False
DEBUG = bool(int(_config["BASE"].get("DEBUG", "1")))

ALLOWED_HOSTS = ["*"]

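The new DEBUG line parses the flag as bool(int(...)), so "0" becomes False and "1" becomes True; note the fallback "1" means a missing key enables debug. configparser's own getboolean() is an alternative that also accepts true/false, yes/no, and on/off (shown here only as an aside, not part of the diff):

    DEBUG = bool(int(_config["BASE"].get("DEBUG", "1")))  # "0" -> False, "1" -> True
    # Alternative reading the same section/key, accepting more spellings:
    DEBUG = _config.getboolean("BASE", "DEBUG", fallback=True)
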
@ -44,7 +51,7 @@ INSTALLED_APPS = [
    "django.contrib.messages",
    "django.contrib.staticfiles",
    "public",
    "pr"
    "pr",
]

# Security key configuration
@ -68,8 +75,7 @@ ROOT_URLCONF = "pr_manager.urls"
TEMPLATES = [
    {
        "BACKEND": "django.template.backends.django.DjangoTemplates",
        "DIRS": [BASE_DIR / 'templates']
        ,
        "DIRS": [BASE_DIR / 'templates'],
        "APP_DIRS": True,
        "OPTIONS": {
            "context_processors": [
@ -89,12 +95,22 @@ WSGI_APPLICATION = "pr_manager.wsgi.application"
# https://docs.djangoproject.com/en/5.1/ref/settings/#databases

DATABASES = {
    "default": {
    "pg": {
        "ENGINE": "django.db.backends.postgresql",
        "NAME": _config["DATABASE"].get("DB_NAME", "chat_ai_v2"),
        "USER": _config["DATABASE"].get("DB_USER", "admin"),
        "PASSWORD": _config["DATABASE"].get("DB_PASSWORD", "admin123456"),
        "HOST": _config["DATABASE"].get("DB_HOST", "124.222.222.101"),
        "PORT": int(_config["DATABASE"].get("DB_PORT", "5432")),
    },
    "sqlite": {
        "ENGINE": "django.db.backends.sqlite3",
        "NAME": BASE_DIR / "db.sqlite3",
    }
    },
}

DATABASES["default"] = DATABASES[_config["DATABASE"].get("DEFAULT", "sqlite")]

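The block above defines named database configurations and then aliases one of them as "default" according to the DEFAULT key, so with the config.ini shown earlier (DEFAULT = pg) the Postgres entry wins, while sqlite remains the fallback. A stripped-down illustration of the alias trick (values shortened; not the real settings):

    DATABASES = {
        "pg": {"ENGINE": "django.db.backends.postgresql"},
        "sqlite": {"ENGINE": "django.db.backends.sqlite3"},
    }
    choice = "pg"  # from _config["DATABASE"].get("DEFAULT", "sqlite")
    DATABASES["default"] = DATABASES[choice]
    # Note: an unrecognized DEFAULT value would raise KeyError at import time.
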
# Password validation
# https://docs.djangoproject.com/en/5.1/ref/settings/#auth-password-validators
