Code optimization: improve clarity and maintainability.
This commit is contained in:
parent 1988a400c9
commit de84796560
.gitignore (vendored): 1 change

@@ -14,3 +14,4 @@ docs/.cache/
 db.sqlite3
 #pr_agent/
 static/admin/
+config.local.ini
Pipfile: 2 changes

@@ -20,9 +20,9 @@ pygithub = "*"
 python-gitlab = "*"
 retry = "*"
 fastapi = "*"
+psycopg2-binary = "*"
 
 [dev-packages]
 
 [requires]
 python_version = "3.12"
-
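The only dependency added is psycopg2-binary, which points to a move from SQLite (db.sqlite3 is already git-ignored above) toward PostgreSQL, with local connection details presumably kept in the newly ignored config.local.ini. A minimal sketch of the corresponding Django DATABASES setting; every concrete value below is a placeholder and is not taken from this commit:

    # settings.py sketch -- assumes PostgreSQL via psycopg2-binary; all values are placeholders
    DATABASES = {
        "default": {
            "ENGINE": "django.db.backends.postgresql",  # backend that loads psycopg2
            "NAME": "pr_agent",
            "USER": "pr_agent",
            "PASSWORD": "change-me",  # in practice read from config.local.ini or the environment
            "HOST": "127.0.0.1",
            "PORT": "5432",
        }
    }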
Pipfile.lock (generated): 106 changes

@@ -1,7 +1,7 @@
 {
     "_meta": {
         "hash": {
-            "sha256": "420206f7faa4351eabc368a83deae9b7ed9e50b0975ac63a46d6367e9920848b"
+            "sha256": "497c1ff8497659883faf8dcca407665df1b3a37f67720f64b139f9dec8202892"
         },
         "pipfile-spec": 6,
         "requires": {
@@ -169,20 +169,19 @@
         },
         "boto3": {
             "hashes": [
-                "sha256:01015b38017876d79efd7273f35d9a4adfba505237159621365bed21b9b65eca",
-                "sha256:03bd8c93b226f07d944fd6b022e11a307bff94ab6a21d51675d7e3ea81ee8424"
+                "sha256:e58136d52d79425ce26c3c1578bf94d4b2e91ead55fed9f6950406ee9713e6af"
             ],
             "index": "pip_conf_index_global",
             "markers": "python_version >= '3.8'",
-            "version": "==1.37.0"
+            "version": "==1.37.2"
         },
         "botocore": {
             "hashes": [
-                "sha256:b129d091a8360b4152ab65327186bf4e250de827c4a9b7ddf40a72b1acf1f3c1",
-                "sha256:d01661f38c0edac87424344cdf4169f3ab9bc1bf1b677c8b230d025eb66c54a3"
+                "sha256:3f460f3c32cd6d747d5897a9cbde011bf1715abc7bf0a6ea6fdb0b812df63287",
+                "sha256:5f59b966f3cd0c8055ef6f7c2600f7db5f8218071d992e5f95da3f9156d4370f"
             ],
             "markers": "python_version >= '3.8'",
-            "version": "==1.37.0"
+            "version": "==1.37.2"
         },
         "certifi": {
             "hashes": [
@@ -460,12 +459,12 @@
         },
         "django-import-export": {
             "hashes": [
-                "sha256:317842a64233025a277040129fb6792fc48fd39622c185b70bf8c18c393d708f",
-                "sha256:ecb4e6cdb4790d69bce261f9cca1007ca19cb431bb5a950ba907898245c8817b"
+                "sha256:5514d09636e84e823a42cd5e79292f70f20d6d2feed117a145f5b64a5b44f168",
+                "sha256:bd3fe0aa15a2bce9de4be1a2f882e2c4539fdbfdfa16f2052c98dd7aec0f085c"
             ],
             "index": "pip_conf_index_global",
             "markers": "python_version >= '3.9'",
-            "version": "==4.3.6"
+            "version": "==4.3.7"
         },
         "django-simpleui": {
             "hashes": [
@@ -794,12 +793,12 @@
         },
         "litellm": {
             "hashes": [
-                "sha256:02df5865f98ea9734a4d27ac7c33aad9a45c4015403d5c0797d3292ade3c5cb5",
-                "sha256:d241436ac0edf64ec57fb5686f8d84a25998a7e52213d9063adf87df8432701f"
+                "sha256:eaab989c090ccc094b41c3fdf27d1df7f6fb25e091ab0ce48e0f3079f1e51ff5",
+                "sha256:ff9137c008cdb421db32defb1fbd1ed546a95167de6d276c61b664582ed4ff60"
            ],
             "index": "pip_conf_index_global",
             "markers": "python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7' and python_version >= '3.8'",
-            "version": "==1.61.16"
+            "version": "==1.61.17"
         },
         "loguru": {
             "hashes": [
@@ -1196,6 +1195,81 @@
             "markers": "python_version >= '3.6'",
             "version": "==7.0.0"
         },
+        "psycopg2-binary": {
+            "hashes": [
+                "sha256:04392983d0bb89a8717772a193cfaac58871321e3ec69514e1c4e0d4957b5aff",
+                "sha256:056470c3dc57904bbf63d6f534988bafc4e970ffd50f6271fc4ee7daad9498a5",
+                "sha256:0ea8e3d0ae83564f2fc554955d327fa081d065c8ca5cc6d2abb643e2c9c1200f",
+                "sha256:155e69561d54d02b3c3209545fb08938e27889ff5a10c19de8d23eb5a41be8a5",
+                "sha256:18c5ee682b9c6dd3696dad6e54cc7ff3a1a9020df6a5c0f861ef8bfd338c3ca0",
+                "sha256:19721ac03892001ee8fdd11507e6a2e01f4e37014def96379411ca99d78aeb2c",
+                "sha256:1a6784f0ce3fec4edc64e985865c17778514325074adf5ad8f80636cd029ef7c",
+                "sha256:2286791ececda3a723d1910441c793be44625d86d1a4e79942751197f4d30341",
+                "sha256:230eeae2d71594103cd5b93fd29d1ace6420d0b86f4778739cb1a5a32f607d1f",
+                "sha256:245159e7ab20a71d989da00f280ca57da7641fa2cdcf71749c193cea540a74f7",
+                "sha256:26540d4a9a4e2b096f1ff9cce51253d0504dca5a85872c7f7be23be5a53eb18d",
+                "sha256:270934a475a0e4b6925b5f804e3809dd5f90f8613621d062848dd82f9cd62007",
+                "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142",
+                "sha256:2ad26b467a405c798aaa1458ba09d7e2b6e5f96b1ce0ac15d82fd9f95dc38a92",
+                "sha256:2b3d2491d4d78b6b14f76881905c7a8a8abcf974aad4a8a0b065273a0ed7a2cb",
+                "sha256:2ce3e21dc3437b1d960521eca599d57408a695a0d3c26797ea0f72e834c7ffe5",
+                "sha256:30e34c4e97964805f715206c7b789d54a78b70f3ff19fbe590104b71c45600e5",
+                "sha256:3216ccf953b3f267691c90c6fe742e45d890d8272326b4a8b20850a03d05b7b8",
+                "sha256:32581b3020c72d7a421009ee1c6bf4a131ef5f0a968fab2e2de0c9d2bb4577f1",
+                "sha256:35958ec9e46432d9076286dda67942ed6d968b9c3a6a2fd62b48939d1d78bf68",
+                "sha256:3abb691ff9e57d4a93355f60d4f4c1dd2d68326c968e7db17ea96df3c023ef73",
+                "sha256:3c18f74eb4386bf35e92ab2354a12c17e5eb4d9798e4c0ad3a00783eae7cd9f1",
+                "sha256:3c4745a90b78e51d9ba06e2088a2fe0c693ae19cc8cb051ccda44e8df8a6eb53",
+                "sha256:3c4ded1a24b20021ebe677b7b08ad10bf09aac197d6943bfe6fec70ac4e4690d",
+                "sha256:3e9c76f0ac6f92ecfc79516a8034a544926430f7b080ec5a0537bca389ee0906",
+                "sha256:48b338f08d93e7be4ab2b5f1dbe69dc5e9ef07170fe1f86514422076d9c010d0",
+                "sha256:4b3df0e6990aa98acda57d983942eff13d824135fe2250e6522edaa782a06de2",
+                "sha256:512d29bb12608891e349af6a0cccedce51677725a921c07dba6342beaf576f9a",
+                "sha256:5a507320c58903967ef7384355a4da7ff3f28132d679aeb23572753cbf2ec10b",
+                "sha256:5c370b1e4975df846b0277b4deba86419ca77dbc25047f535b0bb03d1a544d44",
+                "sha256:6b269105e59ac96aba877c1707c600ae55711d9dcd3fc4b5012e4af68e30c648",
+                "sha256:6d4fa1079cab9018f4d0bd2db307beaa612b0d13ba73b5c6304b9fe2fb441ff7",
+                "sha256:6dc08420625b5a20b53551c50deae6e231e6371194fa0651dbe0fb206452ae1f",
+                "sha256:73aa0e31fa4bb82578f3a6c74a73c273367727de397a7a0f07bd83cbea696baa",
+                "sha256:7559bce4b505762d737172556a4e6ea8a9998ecac1e39b5233465093e8cee697",
+                "sha256:79625966e176dc97ddabc142351e0409e28acf4660b88d1cf6adb876d20c490d",
+                "sha256:7a813c8bdbaaaab1f078014b9b0b13f5de757e2b5d9be6403639b298a04d218b",
+                "sha256:7b2c956c028ea5de47ff3a8d6b3cc3330ab45cf0b7c3da35a2d6ff8420896526",
+                "sha256:7f4152f8f76d2023aac16285576a9ecd2b11a9895373a1f10fd9db54b3ff06b4",
+                "sha256:7f5d859928e635fa3ce3477704acee0f667b3a3d3e4bb109f2b18d4005f38287",
+                "sha256:851485a42dbb0bdc1edcdabdb8557c09c9655dfa2ca0460ff210522e073e319e",
+                "sha256:8608c078134f0b3cbd9f89b34bd60a943b23fd33cc5f065e8d5f840061bd0673",
+                "sha256:880845dfe1f85d9d5f7c412efea7a08946a46894537e4e5d091732eb1d34d9a0",
+                "sha256:8aabf1c1a04584c168984ac678a668094d831f152859d06e055288fa515e4d30",
+                "sha256:8aecc5e80c63f7459a1a2ab2c64df952051df196294d9f739933a9f6687e86b3",
+                "sha256:8cd9b4f2cfab88ed4a9106192de509464b75a906462fb846b936eabe45c2063e",
+                "sha256:8de718c0e1c4b982a54b41779667242bc630b2197948405b7bd8ce16bcecac92",
+                "sha256:9440fa522a79356aaa482aa4ba500b65f28e5d0e63b801abf6aa152a29bd842a",
+                "sha256:b5f86c56eeb91dc3135b3fd8a95dc7ae14c538a2f3ad77a19645cf55bab1799c",
+                "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8",
+                "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909",
+                "sha256:c3cc28a6fd5a4a26224007712e79b81dbaee2ffb90ff406256158ec4d7b52b47",
+                "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864",
+                "sha256:d00924255d7fc916ef66e4bf22f354a940c67179ad3fd7067d7a0a9c84d2fbfc",
+                "sha256:d7cd730dfa7c36dbe8724426bf5612798734bff2d3c3857f36f2733f5bfc7c00",
+                "sha256:e217ce4d37667df0bc1c397fdcd8de5e81018ef305aed9415c3b093faaeb10fb",
+                "sha256:e3923c1d9870c49a2d44f795df0c889a22380d36ef92440ff618ec315757e539",
+                "sha256:e5720a5d25e3b99cd0dc5c8a440570469ff82659bb09431c1439b92caf184d3b",
+                "sha256:e8b58f0a96e7a1e341fc894f62c1177a7c83febebb5ff9123b579418fdc8a481",
+                "sha256:e984839e75e0b60cfe75e351db53d6db750b00de45644c5d1f7ee5d1f34a1ce5",
+                "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4",
+                "sha256:ec8a77f521a17506a24a5f626cb2aee7850f9b69a0afe704586f63a464f3cd64",
+                "sha256:ecced182e935529727401b24d76634a357c71c9275b356efafd8a2a91ec07392",
+                "sha256:ee0e8c683a7ff25d23b55b11161c2663d4b099770f6085ff0a20d4505778d6b4",
+                "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1",
+                "sha256:f758ed67cab30b9a8d2833609513ce4d3bd027641673d4ebc9c067e4d208eec1",
+                "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567",
+                "sha256:ffe8ed017e4ed70f68b7b371d84b7d4a790368db9203dfc2d222febd3a9c8863"
+            ],
+            "index": "pip_conf_index_global",
+            "markers": "python_version >= '3.8'",
+            "version": "==2.9.10"
+        },
         "py": {
             "hashes": [
                 "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719",
@@ -1766,11 +1840,11 @@
         },
         "s3transfer": {
             "hashes": [
-                "sha256:3b39185cb72f5acc77db1a58b6e25b977f28d20496b6e58d6813d75f464d632f",
-                "sha256:be6ecb39fadd986ef1701097771f87e4d2f821f27f6071c872143884d2950fbc"
+                "sha256:ca855bdeb885174b5ffa95b9913622459d4ad8e331fc98eb01e6d5eb6a30655d",
+                "sha256:edae4977e3a122445660c7c114bba949f9d191bae3b34a096f18a1c8c354527a"
             ],
             "markers": "python_version >= '3.8'",
-            "version": "==0.11.2"
+            "version": "==0.11.3"
         },
         "simplepro": {
             "hashes": [
@@ -34,7 +34,13 @@ class GitConfigAdmin(AjaxAdmin):
 class ProjectConfigAdmin(AjaxAdmin):
     """Admin配置"""
 
-    list_display = ["project_id", "project_name", "project_secret", "commands", "is_enable"]
+    list_display = [
+        "project_id",
+        "project_name",
+        "project_secret",
+        "commands",
+        "is_enable",
+    ]
     readonly_fields = ["create_by", "delete_at", "detail"]
     top_html = '<el-alert title="可配置多个项目!" type="success"></el-alert>'
 
@@ -16,4 +16,3 @@ class Command(BaseCommand):
             print("初始化AI配置已创建")
         else:
             print("初始化AI配置已存在")
-
@@ -44,9 +44,7 @@ class GitConfig(BaseModel):
         null=True, blank=True, max_length=16, verbose_name="Git名称"
     )
     git_type = fields.RadioField(
-        choices=constant.GIT_TYPE,
-        default=0,
-        verbose_name="Git类型"
+        choices=constant.GIT_TYPE, default=0, verbose_name="Git类型"
     )
     git_url = fields.CharField(
         null=True, blank=True, max_length=128, verbose_name="Git地址"
@@ -67,6 +65,7 @@ class ProjectConfig(BaseModel):
     """
     项目配置表
     """
+
     git_config = fields.ForeignKey(
         GitConfig,
         null=True,
@@ -89,10 +88,7 @@ class ProjectConfig(BaseModel):
         max_length=256,
         verbose_name="默认命令",
     )
-    is_enable = fields.SwitchField(
-        default=True,
-        verbose_name="是否启用"
-    )
+    is_enable = fields.SwitchField(default=True, verbose_name="是否启用")
 
     class Meta:
         verbose_name = "项目配置"
@@ -106,6 +102,7 @@ class ProjectHistory(BaseModel):
     """
     项目历史表
     """
+
     project = fields.ForeignKey(
         ProjectConfig,
         null=True,
@@ -128,9 +125,7 @@ class ProjectHistory(BaseModel):
     mr_title = fields.CharField(
         null=True, blank=True, max_length=256, verbose_name="MR标题"
     )
-    source_data = models.JSONField(
-        null=True, blank=True, verbose_name="源数据"
-    )
+    source_data = models.JSONField(null=True, blank=True, verbose_name="源数据")
 
     class Meta:
         verbose_name = "项目历史"
@@ -12,12 +12,7 @@ from utils import constant
 
 
 def load_project_config(
-    git_url,
-    access_token,
-    project_secret,
-    openai_api_base,
-    openai_key,
-    llm_model
+    git_url, access_token, project_secret, openai_api_base, openai_key, llm_model
 ):
     """
     加载项目配置
@@ -36,12 +31,11 @@ def load_project_config(
         "secret": project_secret,
         "openai_api_base": openai_api_base,
         "openai_key": openai_key,
-        "llm_model": llm_model
+        "llm_model": llm_model,
     }
 
 
 class WebHookView(View):
-
     @staticmethod
     def select_git_provider(git_type):
         """
@@ -82,7 +76,9 @@ class WebHookView(View):
         project_config = provider.get_project_config(project_id=project_id)
 
         # Token 校验
-        provider.check_secret(request_headers=headers, project_secret=project_config.get("project_secret"))
+        provider.check_secret(
+            request_headers=headers, project_secret=project_config.get("project_secret")
+        )
 
         provider.get_merge_request(
             request_data=json_data,
@@ -91,11 +87,13 @@ class WebHookView(View):
             api_base=project_config.get("api_base"),
             api_key=project_config.get("api_key"),
             llm_model=project_config.get("llm_model"),
-            project_commands=project_config.get("commands")
+            project_commands=project_config.get("commands"),
         )
 
         # 记录请求日志: 目前仅记录合并日志
         if json_data.get('object_kind') == 'merge_request':
-            provider.save_pr_agent_log(request_data=json_data, project_id=project_config.get("project_id"))
+            provider.save_pr_agent_log(
+                request_data=json_data, project_id=project_config.get("project_id")
+            )
 
         return JsonResponse(status=200, data={"status": "ignored"})
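For orientation, the keys the webhook view reads above (object_kind, project.web_url, and the object_attributes fields) follow the shape of a GitLab merge-request webhook payload. A trimmed, illustrative example; every value is invented:

    # Illustrative GitLab merge_request payload (only the fields used above; values invented)
    json_data = {
        "object_kind": "merge_request",
        "project": {"web_url": "https://gitlab.example.com/group/repo"},
        "object_attributes": {
            "url": "https://gitlab.example.com/group/repo/-/merge_requests/1",
            "source_branch": "feature/example",
            "target_branch": "main",
            "title": "Example MR",
            "action": "open",
        },
    }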
@@ -1,8 +1,4 @@
-GIT_TYPE = (
-    (0, "gitlab"),
-    (1, "github"),
-    (2, "gitea")
-)
+GIT_TYPE = ((0, "gitlab"), (1, "github"), (2, "gitea"))
 
 DEFAULT_COMMANDS = (
     ("/review", "/review"),
@@ -10,11 +6,7 @@ DEFAULT_COMMANDS = (
     ("/improve_code", "/improve_code"),
 )
 
-UA_TYPE = {
-    "GitLab": "gitlab",
-    "GitHub": "github",
-    "Go-http-client": "gitea"
-}
+UA_TYPE = {"GitLab": "gitlab", "GitHub": "github", "Go-http-client": "gitea"}
 
 
 def get_git_type_from_ua(ua_value):
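get_git_type_from_ua is only declared above; its body is outside this diff. Given the UA_TYPE mapping, one plausible implementation is a prefix match on the webhook's User-Agent header -- a sketch, not the project's actual code:

    def get_git_type_from_ua(ua_value):
        # Hypothetical: maps 'GitLab/16.9', 'GitHub-Hookshot/...', 'Go-http-client/1.1' to a provider name.
        for prefix, git_type in UA_TYPE.items():
            if ua_value and ua_value.startswith(prefix):
                return git_type
        return None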
@@ -16,14 +16,14 @@ class GitProvider(ABC):
 
     @abstractmethod
     def get_merge_request(
         self,
         request_data,
         git_url,
         access_token,
         api_base,
         api_key,
         llm_model,
-        project_commands
+        project_commands,
     ):
         pass
 
@@ -33,7 +33,6 @@ class GitProvider(ABC):
 
 
 class GitLabProvider(GitProvider):
-
     @staticmethod
     def check_secret(request_headers, project_secret):
         """
@@ -79,18 +78,18 @@ class GitLabProvider(GitProvider):
             "access_token": git_config.access_token,
             "project_secret": project_config.project_secret,
             "commands": project_config.commands.split(","),
-            "project_id": project_config.id
+            "project_id": project_config.id,
         }
 
     def get_merge_request(
         self,
         request_data,
         git_url,
         access_token,
         api_base,
         api_key,
         llm_model,
         project_commands,
     ):
         """
         实现GitLab Merge Request获取逻辑
@@ -124,7 +123,10 @@ class GitLabProvider(GitProvider):
                 self.run_command(mr_url, project_commands)
                 # 数据库留存
                 return JsonResponse(status=200, data={"status": "review started"})
-        return JsonResponse(status=400, data={"error": "Merge request URL not found or action not open"})
+        return JsonResponse(
+            status=400,
+            data={"error": "Merge request URL not found or action not open"},
+        )
 
     @staticmethod
     def save_pr_agent_log(request_data, project_id):
@@ -134,13 +136,19 @@ class GitLabProvider(GitProvider):
         :param project_id:
         :return:
         """
-        if request_data.get('object_attributes', {}).get("source_branch") and request_data.get('object_attributes', {}).get("target_branch"):
+        if request_data.get('object_attributes', {}).get(
+            "source_branch"
+        ) and request_data.get('object_attributes', {}).get("target_branch"):
             models.ProjectHistory.objects.create(
                 project_id=project_id,
                 project_url=request_data.get("project", {}).get("web_url"),
                 mr_url=request_data.get('object_attributes', {}).get("url"),
-                source_branch=request_data.get('object_attributes', {}).get("source_branch"),
-                target_branch=request_data.get('object_attributes', {}).get("target_branch"),
+                source_branch=request_data.get('object_attributes', {}).get(
+                    "source_branch"
+                ),
+                target_branch=request_data.get('object_attributes', {}).get(
+                    "target_branch"
+                ),
                 mr_title=request_data.get('object_attributes', {}).get("title"),
                 source_data=request_data,
             )
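check_secret appears above only as a signature and the start of its docstring. GitLab delivers the configured webhook secret in the X-Gitlab-Token request header, so the essence of such a check (a sketch under that assumption, not the repository's code) is:

    @staticmethod
    def check_secret(request_headers, project_secret):
        # GitLab sends the secret token in the X-Gitlab-Token header.
        if request_headers.get("X-Gitlab-Token") != project_secret:
            raise PermissionError("Webhook secret mismatch")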
@@ -80,14 +80,20 @@ class PRAgent:
         if action == "answer":
             if notify:
                 notify()
-            await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run()
+            await PRReviewer(
+                pr_url, is_answer=True, args=args, ai_handler=self.ai_handler
+            ).run()
         elif action == "auto_review":
-            await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run()
+            await PRReviewer(
+                pr_url, is_auto=True, args=args, ai_handler=self.ai_handler
+            ).run()
         elif action in command2class:
             if notify:
                 notify()
 
-            await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run()
+            await command2class[action](
+                pr_url, ai_handler=self.ai_handler, args=args
+            ).run()
         else:
             return False
         return True
@@ -88,7 +88,7 @@ USER_MESSAGE_ONLY_MODELS = [
     "deepseek/deepseek-reasoner",
     "o1-mini",
     "o1-mini-2024-09-12",
-    "o1-preview"
+    "o1-preview",
 ]
 
 NO_SUPPORT_TEMPERATURE_MODELS = [
@@ -99,5 +99,5 @@ NO_SUPPORT_TEMPERATURE_MODELS = [
     "o1-2024-12-17",
     "o3-mini",
     "o3-mini-2025-01-31",
-    "o1-preview"
+    "o1-preview",
 ]
@@ -16,7 +16,14 @@ class BaseAiHandler(ABC):
         pass
 
     @abstractmethod
-    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
+    async def chat_completion(
+        self,
+        model: str,
+        system: str,
+        user: str,
+        temperature: float = 0.2,
+        img_path: str = None,
+    ):
         """
         This method should be implemented to return a chat completion from the AI model.
         Args:
@@ -34,9 +34,16 @@ class LangChainOpenAIHandler(BaseAiHandler):
         """
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
 
-    @retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError),
-           tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
-    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
+    @retry(
+        exceptions=(APIError, Timeout, AttributeError, RateLimitError),
+        tries=OPENAI_RETRIES,
+        delay=2,
+        backoff=2,
+        jitter=(1, 3),
+    )
+    async def chat_completion(
+        self, model: str, system: str, user: str, temperature: float = 0.2
+    ):
         try:
             messages = [SystemMessage(content=system), HumanMessage(content=user)]
 
@@ -45,7 +52,7 @@ class LangChainOpenAIHandler(BaseAiHandler):
             finish_reason = "completed"
             return resp.content, finish_reason
 
-        except (Exception) as e:
+        except Exception as e:
             get_logger().error("Unknown error during OpenAI inference: ", e)
             raise e
 
@@ -66,7 +73,10 @@ class LangChainOpenAIHandler(BaseAiHandler):
             if openai_api_base is None or len(openai_api_base) == 0:
                 return ChatOpenAI(openai_api_key=get_settings().openai.key)
             else:
-                return ChatOpenAI(openai_api_key=get_settings().openai.key, openai_api_base=openai_api_base)
+                return ChatOpenAI(
+                    openai_api_key=get_settings().openai.key,
+                    openai_api_base=openai_api_base,
+                )
         except AttributeError as e:
             if getattr(e, "name"):
                 raise ValueError(f"OpenAI {e.name} is required") from e
@@ -36,9 +36,14 @@ class LiteLLMAIHandler(BaseAiHandler):
         elif 'OPENAI_API_KEY' not in os.environ:
             litellm.api_key = "dummy_key"
         if get_settings().get("aws.AWS_ACCESS_KEY_ID"):
-            assert get_settings().aws.AWS_SECRET_ACCESS_KEY and get_settings().aws.AWS_REGION_NAME, "AWS credentials are incomplete"
+            assert (
+                get_settings().aws.AWS_SECRET_ACCESS_KEY
+                and get_settings().aws.AWS_REGION_NAME
+            ), "AWS credentials are incomplete"
             os.environ["AWS_ACCESS_KEY_ID"] = get_settings().aws.AWS_ACCESS_KEY_ID
-            os.environ["AWS_SECRET_ACCESS_KEY"] = get_settings().aws.AWS_SECRET_ACCESS_KEY
+            os.environ[
+                "AWS_SECRET_ACCESS_KEY"
+            ] = get_settings().aws.AWS_SECRET_ACCESS_KEY
             os.environ["AWS_REGION_NAME"] = get_settings().aws.AWS_REGION_NAME
         if get_settings().get("litellm.use_client"):
             litellm_token = get_settings().get("litellm.LITELLM_TOKEN")
@@ -73,14 +78,19 @@ class LiteLLMAIHandler(BaseAiHandler):
             litellm.replicate_key = get_settings().replicate.key
         if get_settings().get("HUGGINGFACE.KEY", None):
             litellm.huggingface_key = get_settings().huggingface.key
-        if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
+        if (
+            get_settings().get("HUGGINGFACE.API_BASE", None)
+            and 'huggingface' in get_settings().config.model
+        ):
             litellm.api_base = get_settings().huggingface.api_base
             self.api_base = get_settings().huggingface.api_base
         if get_settings().get("OLLAMA.API_BASE", None):
             litellm.api_base = get_settings().ollama.api_base
             self.api_base = get_settings().ollama.api_base
         if get_settings().get("HUGGINGFACE.REPETITION_PENALTY", None):
-            self.repetition_penalty = float(get_settings().huggingface.repetition_penalty)
+            self.repetition_penalty = float(
+                get_settings().huggingface.repetition_penalty
+            )
         if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
             litellm.vertex_project = get_settings().vertexai.vertex_project
             litellm.vertex_location = get_settings().get(
@@ -89,7 +99,9 @@ class LiteLLMAIHandler(BaseAiHandler):
         # Google AI Studio
         # SEE https://docs.litellm.ai/docs/providers/gemini
         if get_settings().get("GOOGLE_AI_STUDIO.GEMINI_API_KEY", None):
-            os.environ["GEMINI_API_KEY"] = get_settings().google_ai_studio.gemini_api_key
+            os.environ[
+                "GEMINI_API_KEY"
+            ] = get_settings().google_ai_studio.gemini_api_key
 
         # Support deepseek models
         if get_settings().get("DEEPSEEK.KEY", None):
@@ -140,27 +152,35 @@ class LiteLLMAIHandler(BaseAiHandler):
         git_provider = get_settings().config.git_provider
 
         metadata = dict()
-        callbacks = litellm.success_callback + litellm.failure_callback + litellm.service_callback
+        callbacks = (
+            litellm.success_callback
+            + litellm.failure_callback
+            + litellm.service_callback
+        )
         if "langfuse" in callbacks:
-            metadata.update({
-                "trace_name": command,
-                "tags": [git_provider, command, f'version:{get_version()}'],
-                "trace_metadata": {
-                    "command": command,
-                    "pr_url": pr_url,
-                },
-            })
-        if "langsmith" in callbacks:
-            metadata.update({
-                "run_name": command,
-                "tags": [git_provider, command, f'version:{get_version()}'],
-                "extra": {
-                    "metadata": {
-                        "command": command,
-                        "pr_url": pr_url,
-                    }
-                },
-            })
+            metadata.update(
+                {
+                    "trace_name": command,
+                    "tags": [git_provider, command, f'version:{get_version()}'],
+                    "trace_metadata": {
+                        "command": command,
+                        "pr_url": pr_url,
+                    },
+                }
+            )
+        if "langsmith" in callbacks:
+            metadata.update(
+                {
+                    "run_name": command,
+                    "tags": [git_provider, command, f'version:{get_version()}'],
+                    "extra": {
+                        "metadata": {
+                            "command": command,
+                            "pr_url": pr_url,
+                        }
+                    },
+                }
+            )
 
         # Adding the captured logs to the kwargs
         kwargs["metadata"] = metadata
@@ -175,10 +195,19 @@ class LiteLLMAIHandler(BaseAiHandler):
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
 
     @retry(
-        retry=retry_if_exception_type((openai.APIError, openai.APIConnectionError, openai.APITimeoutError)),  # No retry on RateLimitError
-        stop=stop_after_attempt(OPENAI_RETRIES)
+        retry=retry_if_exception_type(
+            (openai.APIError, openai.APIConnectionError, openai.APITimeoutError)
+        ),  # No retry on RateLimitError
+        stop=stop_after_attempt(OPENAI_RETRIES),
     )
-    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
+    async def chat_completion(
+        self,
+        model: str,
+        system: str,
+        user: str,
+        temperature: float = 0.2,
+        img_path: str = None,
+    ):
         try:
             resp, finish_reason = None, None
             deployment_id = self.deployment_id
@@ -187,8 +216,12 @@ class LiteLLMAIHandler(BaseAiHandler):
             if 'claude' in model and not system:
                 system = "No system prompt provided"
                 get_logger().warning(
-                    "Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error.")
-            messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
+                    "Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error."
+                )
+            messages = [
+                {"role": "system", "content": system},
+                {"role": "user", "content": user},
+            ]
 
             if img_path:
                 try:
@@ -201,14 +234,21 @@ class LiteLLMAIHandler(BaseAiHandler):
                 except Exception as e:
                     get_logger().error(f"Error fetching image: {img_path}", e)
                     return f"Error fetching image: {img_path}", "error"
-                messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]},
-                                          {"type": "image_url", "image_url": {"url": img_path}}]
+                messages[1]["content"] = [
+                    {"type": "text", "text": messages[1]["content"]},
+                    {"type": "image_url", "image_url": {"url": img_path}},
+                ]
 
             # Currently, some models do not support a separate system and user prompts
-            if model in self.user_message_only_models or get_settings().config.custom_reasoning_model:
+            if (
+                model in self.user_message_only_models
+                or get_settings().config.custom_reasoning_model
+            ):
                 user = f"{system}\n\n\n{user}"
                 system = ""
-                get_logger().info(f"Using model {model}, combining system and user prompts")
+                get_logger().info(
+                    f"Using model {model}, combining system and user prompts"
+                )
                 messages = [{"role": "user", "content": user}]
             kwargs = {
                 "model": model,
@@ -227,7 +267,10 @@ class LiteLLMAIHandler(BaseAiHandler):
             }
 
             # Add temperature only if model supports it
-            if model not in self.no_support_temperature_models and not get_settings().config.custom_reasoning_model:
+            if (
+                model not in self.no_support_temperature_models
+                and not get_settings().config.custom_reasoning_model
+            ):
                 kwargs["temperature"] = temperature
 
             if get_settings().litellm.get("enable_callbacks", False):
@@ -235,7 +278,9 @@ class LiteLLMAIHandler(BaseAiHandler):
 
             seed = get_settings().config.get("seed", -1)
             if temperature > 0 and seed >= 0:
-                raise ValueError(f"Seed ({seed}) is not supported with temperature ({temperature}) > 0")
+                raise ValueError(
+                    f"Seed ({seed}) is not supported with temperature ({temperature}) > 0"
+                )
             elif seed >= 0:
                 get_logger().info(f"Using fixed seed of {seed}")
                 kwargs["seed"] = seed
@@ -253,10 +298,10 @@ class LiteLLMAIHandler(BaseAiHandler):
         except (openai.APIError, openai.APITimeoutError) as e:
             get_logger().warning(f"Error during LLM inference: {e}")
             raise
-        except (openai.RateLimitError) as e:
+        except openai.RateLimitError as e:
             get_logger().error(f"Rate limit error during LLM inference: {e}")
             raise
-        except (Exception) as e:
+        except Exception as e:
             get_logger().warning(f"Unknown error during LLM inference: {e}")
             raise openai.APIError from e
         if response is None or len(response["choices"]) == 0:
@@ -267,7 +312,9 @@ class LiteLLMAIHandler(BaseAiHandler):
         get_logger().debug(f"\nAI response:\n{resp}")
 
         # log the full response for debugging
-        response_log = self.prepare_logs(response, system, user, resp, finish_reason)
+        response_log = self.prepare_logs(
+            response, system, user, resp, finish_reason
+        )
         get_logger().debug("Full_response", artifact=response_log)
 
         # for CLI debugging
@@ -37,13 +37,23 @@ class OpenAIHandler(BaseAiHandler):
         """
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
 
-    @retry(exceptions=(APIError, Timeout, AttributeError, RateLimitError),
-           tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
-    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
+    @retry(
+        exceptions=(APIError, Timeout, AttributeError, RateLimitError),
+        tries=OPENAI_RETRIES,
+        delay=2,
+        backoff=2,
+        jitter=(1, 3),
+    )
+    async def chat_completion(
+        self, model: str, system: str, user: str, temperature: float = 0.2
+    ):
         try:
             get_logger().info("System: ", system)
             get_logger().info("User: ", user)
-            messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
+            messages = [
+                {"role": "system", "content": system},
+                {"role": "user", "content": user},
+            ]
             client = AsyncOpenAI()
             chat_completion = await client.chat.completions.create(
                 model=model,
@@ -53,15 +63,21 @@ class OpenAIHandler(BaseAiHandler):
             resp = chat_completion.choices[0].message.content
             finish_reason = chat_completion.choices[0].finish_reason
             usage = chat_completion.usage
-            get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
-                              model=model, usage=usage)
+            get_logger().info(
+                "AI response",
+                response=resp,
+                messages=messages,
+                finish_reason=finish_reason,
+                model=model,
+                usage=usage,
+            )
             return resp, finish_reason
         except (APIError, Timeout) as e:
             get_logger().error("Error during OpenAI inference: ", e)
             raise
-        except (RateLimitError) as e:
+        except RateLimitError as e:
             get_logger().error("Rate limit error during OpenAI inference: ", e)
             raise
-        except (Exception) as e:
+        except Exception as e:
             get_logger().error("Unknown error during OpenAI inference: ", e)
             raise
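Both AI handlers above wrap chat_completion with the retry package's decorator. With delay=2, backoff=2 and jitter=(1, 3), the wait between failed attempts starts at 2 s and then doubles plus 1-3 s of random jitter, for up to tries attempts (OPENAI_RETRIES is defined elsewhere and its value is not shown in this diff). A standalone illustration with assumed values:

    from retry import retry

    @retry(exceptions=(TimeoutError,), tries=3, delay=2, backoff=2, jitter=(1, 3))
    def flaky_call():
        # Retried for up to 3 attempts; waits roughly 2 s, then about 5-7 s, between failures.
        ...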
@@ -1,6 +1,7 @@
 from base64 import b64decode
 import hashlib
 
+
 class CliArgs:
     @staticmethod
     def validate_user_args(args: list) -> (bool, str):
@@ -23,12 +24,12 @@ class CliArgs:
             for arg in args:
                 if arg.startswith('--'):
                     arg_word = arg.lower()
-                    arg_word = arg_word.replace('__', '.')  # replace double underscore with dot, e.g. --openai__key -> --openai.key
+                    arg_word = arg_word.replace(
+                        '__', '.'
+                    )  # replace double underscore with dot, e.g. --openai__key -> --openai.key
                     for forbidden_arg_word in forbidden_cli_args:
                         if forbidden_arg_word in arg_word:
                             return False, forbidden_arg_word
             return True, ""
         except Exception as e:
             return False, str(e)
 
-
@@ -4,7 +4,7 @@ import re
 from utils.pr_agent.config_loader import get_settings
 
 
-def filter_ignored(files, platform = 'github'):
+def filter_ignored(files, platform='github'):
     """
     Filter out files that match the ignore patterns.
     """
@@ -15,7 +15,9 @@ def filter_ignored(files, platform = 'github'):
         if isinstance(patterns, str):
             patterns = [patterns]
         glob_setting = get_settings().ignore.glob
-        if isinstance(glob_setting, str):  # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
+        if isinstance(
+            glob_setting, str
+        ):  # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
             glob_setting = glob_setting.strip('[]').split(",")
         patterns += [fnmatch.translate(glob) for glob in glob_setting]
 
@@ -31,7 +33,9 @@ def filter_ignored(files, platform = 'github'):
         if files and isinstance(files, list):
             for r in compiled_patterns:
                 if platform == 'github':
-                    files = [f for f in files if (f.filename and not r.match(f.filename))]
+                    files = [
+                        f for f in files if (f.filename and not r.match(f.filename))
+                    ]
                 elif platform == 'bitbucket':
                     # files = [f for f in files if (f.new.path and not r.match(f.new.path))]
                     files_o = []
@@ -49,10 +53,18 @@ def filter_ignored(files, platform = 'github'):
                     # files = [f for f in files if (f['new_path'] and not r.match(f['new_path']))]
                     files_o = []
                     for f in files:
-                        if 'new_path' in f and f['new_path'] and not r.match(f['new_path']):
+                        if (
+                            'new_path' in f
+                            and f['new_path']
+                            and not r.match(f['new_path'])
+                        ):
                             files_o.append(f)
                             continue
-                        if 'old_path' in f and f['old_path'] and not r.match(f['old_path']):
+                        if (
+                            'old_path' in f
+                            and f['old_path']
+                            and not r.match(f['old_path'])
+                        ):
                             files_o.append(f)
                             continue
                     files = files_o
@ -8,9 +8,18 @@ from utils.pr_agent.config_loader import get_settings
|
|||||||
from utils.pr_agent.log import get_logger
|
from utils.pr_agent.log import get_logger
|
||||||
|
|
||||||
|
|
||||||
def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
|
def extend_patch(
|
||||||
patch_extra_lines_after=0, filename: str = "") -> str:
|
original_file_str,
|
||||||
if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0) or not original_file_str:
|
patch_str,
|
||||||
|
patch_extra_lines_before=0,
|
||||||
|
patch_extra_lines_after=0,
|
||||||
|
filename: str = "",
|
||||||
|
) -> str:
|
||||||
|
if (
|
||||||
|
not patch_str
|
||||||
|
or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0)
|
||||||
|
or not original_file_str
|
||||||
|
):
|
||||||
return patch_str
|
return patch_str
|
||||||
|
|
||||||
original_file_str = decode_if_bytes(original_file_str)
|
original_file_str = decode_if_bytes(original_file_str)
|
||||||
@ -21,10 +30,17 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
|
|||||||
return patch_str
|
return patch_str
|
||||||
|
|
||||||
try:
|
try:
|
||||||
extended_patch_str = process_patch_lines(patch_str, original_file_str,
|
extended_patch_str = process_patch_lines(
|
||||||
patch_extra_lines_before, patch_extra_lines_after)
|
patch_str,
|
||||||
|
original_file_str,
|
||||||
|
patch_extra_lines_before,
|
||||||
|
patch_extra_lines_after,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
|
get_logger().warning(
|
||||||
|
f"Failed to extend patch: {e}",
|
||||||
|
artifact={"traceback": traceback.format_exc()},
|
||||||
|
)
|
||||||
return patch_str
|
return patch_str
|
||||||
|
|
||||||
return extended_patch_str
|
return extended_patch_str
|
||||||
@ -48,13 +64,19 @@ def decode_if_bytes(original_file_str):
|
|||||||
def should_skip_patch(filename):
|
def should_skip_patch(filename):
|
||||||
patch_extension_skip_types = get_settings().config.patch_extension_skip_types
|
patch_extension_skip_types = get_settings().config.patch_extension_skip_types
|
||||||
if patch_extension_skip_types and filename:
|
if patch_extension_skip_types and filename:
|
||||||
return any(filename.endswith(skip_type) for skip_type in patch_extension_skip_types)
|
return any(
|
||||||
|
filename.endswith(skip_type) for skip_type in patch_extension_skip_types
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after):
|
def process_patch_lines(
|
||||||
|
patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after
|
||||||
|
):
|
||||||
allow_dynamic_context = get_settings().config.allow_dynamic_context
|
allow_dynamic_context = get_settings().config.allow_dynamic_context
|
||||||
patch_extra_lines_before_dynamic = get_settings().config.max_extra_lines_before_dynamic_context
|
patch_extra_lines_before_dynamic = (
|
||||||
|
get_settings().config.max_extra_lines_before_dynamic_context
|
||||||
|
)
|
||||||
|
|
||||||
original_lines = original_file_str.splitlines()
|
original_lines = original_file_str.splitlines()
|
||||||
len_original_lines = len(original_lines)
|
len_original_lines = len(original_lines)
|
||||||
@ -63,59 +85,122 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before,
|
|||||||
|
|
||||||
is_valid_hunk = True
|
is_valid_hunk = True
|
||||||
start1, size1, start2, size2 = -1, -1, -1, -1
|
start1, size1, start2, size2 = -1, -1, -1, -1
|
||||||
RE_HUNK_HEADER = re.compile(
|
RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
|
||||||
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
|
|
||||||
try:
|
try:
|
||||||
for i,line in enumerate(patch_lines):
|
for i, line in enumerate(patch_lines):
|
||||||
if line.startswith('@@'):
|
if line.startswith('@@'):
|
||||||
match = RE_HUNK_HEADER.match(line)
|
match = RE_HUNK_HEADER.match(line)
|
||||||
# identify hunk header
|
# identify hunk header
|
||||||
if match:
|
if match:
|
||||||
# finish processing previous hunk
|
# finish processing previous hunk
|
||||||
if is_valid_hunk and (start1 != -1 and patch_extra_lines_after > 0):
|
if is_valid_hunk and (start1 != -1 and patch_extra_lines_after > 0):
|
||||||
delta_lines = [f' {line}' for line in original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]]
|
delta_lines = [
|
||||||
|
f' {line}'
|
||||||
|
for line in original_lines[
|
||||||
|
start1
|
||||||
|
+ size1
|
||||||
|
- 1 : start1
|
||||||
|
+ size1
|
||||||
|
- 1
|
||||||
|
+ patch_extra_lines_after
|
||||||
|
]
|
||||||
|
]
|
||||||
extended_patch_lines.extend(delta_lines)
|
extended_patch_lines.extend(delta_lines)
|
||||||
|
|
||||||
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
|
section_header, size1, size2, start1, start2 = extract_hunk_headers(
|
||||||
|
match
|
||||||
|
)
|
||||||
|
|
||||||
is_valid_hunk = check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1)
|
is_valid_hunk = check_if_hunk_lines_matches_to_file(
|
||||||
|
i, original_lines, patch_lines, start1
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_valid_hunk and (
|
||||||
|
patch_extra_lines_before > 0 or patch_extra_lines_after > 0
|
||||||
|
):
|
||||||
|
|
||||||
if is_valid_hunk and (patch_extra_lines_before > 0 or patch_extra_lines_after > 0):
|
|
||||||
def _calc_context_limits(patch_lines_before):
|
def _calc_context_limits(patch_lines_before):
|
||||||
extended_start1 = max(1, start1 - patch_lines_before)
|
extended_start1 = max(1, start1 - patch_lines_before)
|
||||||
extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after
|
extended_size1 = (
|
||||||
|
size1
|
||||||
|
+ (start1 - extended_start1)
|
||||||
|
+ patch_extra_lines_after
|
||||||
|
)
|
||||||
extended_start2 = max(1, start2 - patch_lines_before)
|
extended_start2 = max(1, start2 - patch_lines_before)
|
||||||
extended_size2 = size2 + (start2 - extended_start2) + patch_extra_lines_after
|
extended_size2 = (
|
||||||
if extended_start1 - 1 + extended_size1 > len_original_lines:
|
size2
|
||||||
|
+ (start2 - extended_start2)
|
||||||
|
+ patch_extra_lines_after
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
extended_start1 - 1 + extended_size1
|
||||||
|
> len_original_lines
|
||||||
|
):
|
||||||
# we cannot extend beyond the original file
|
# we cannot extend beyond the original file
|
||||||
delta_cap = extended_start1 - 1 + extended_size1 - len_original_lines
|
delta_cap = (
|
||||||
|
extended_start1
|
||||||
|
- 1
|
||||||
|
+ extended_size1
|
||||||
|
- len_original_lines
|
||||||
|
)
|
||||||
extended_size1 = max(extended_size1 - delta_cap, size1)
|
extended_size1 = max(extended_size1 - delta_cap, size1)
|
||||||
extended_size2 = max(extended_size2 - delta_cap, size2)
|
extended_size2 = max(extended_size2 - delta_cap, size2)
|
||||||
return extended_start1, extended_size1, extended_start2, extended_size2
|
return (
|
||||||
|
extended_start1,
|
||||||
|
extended_size1,
|
||||||
|
extended_start2,
|
||||||
|
extended_size2,
|
||||||
|
)
|
||||||
|
|
||||||
if allow_dynamic_context:
|
if allow_dynamic_context:
|
||||||
extended_start1, extended_size1, extended_start2, extended_size2 = \
|
(
|
||||||
_calc_context_limits(patch_extra_lines_before_dynamic)
|
extended_start1,
|
||||||
lines_before = original_lines[extended_start1 - 1:start1 - 1]
|
extended_size1,
|
||||||
|
                            extended_start2,
                            extended_size2,
                        ) = _calc_context_limits(patch_extra_lines_before_dynamic)
                        lines_before = original_lines[
                            extended_start1 - 1 : start1 - 1
                        ]
                        found_header = False
                        for (
                            i,
                            line,
                        ) in enumerate(lines_before):
                            if section_header in line:
                                found_header = True
                                # Update start and size in one line each
                                extended_start1, extended_start2 = (
                                    extended_start1 + i,
                                    extended_start2 + i,
                                )
                                extended_size1, extended_size2 = (
                                    extended_size1 - i,
                                    extended_size2 - i,
                                )
                                # get_logger().debug(f"Found section header in line {i} before the hunk")
                                section_header = ''
                                break
                        if not found_header:
                            # get_logger().debug(f"Section header not found in the extra lines before the hunk")
                            (
                                extended_start1,
                                extended_size1,
                                extended_start2,
                                extended_size2,
                            ) = _calc_context_limits(patch_extra_lines_before)
                    else:
                        (
                            extended_start1,
                            extended_size1,
                            extended_start2,
                            extended_size2,
                        ) = _calc_context_limits(patch_extra_lines_before)

                    delta_lines = [
                        f' {line}'
                        for line in original_lines[extended_start1 - 1 : start1 - 1]
                    ]

                    # logic to remove section header if its in the extra delta lines (in dynamic context, this is also done)
                    if section_header and not allow_dynamic_context:
@@ -132,17 +217,23 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before,
                extended_patch_lines.append('')
                extended_patch_lines.append(
                    f'@@ -{extended_start1},{extended_size1} '
                    f'+{extended_start2},{extended_size2} @@ {section_header}'
                )
                extended_patch_lines.extend(delta_lines)  # one to zero based
                continue
            extended_patch_lines.append(line)
    except Exception as e:
        get_logger().warning(
            f"Failed to extend patch: {e}",
            artifact={"traceback": traceback.format_exc()},
        )
        return patch_str

    # finish processing last hunk
    if start1 != -1 and patch_extra_lines_after > 0 and is_valid_hunk:
        delta_lines = original_lines[
            start1 + size1 - 1 : start1 + size1 - 1 + patch_extra_lines_after
        ]
        # add space at the beginning of each extra line
        delta_lines = [f' {line}' for line in delta_lines]
        extended_patch_lines.extend(delta_lines)
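# Illustrative usage sketch (annotation, not part of the diff above): how the
# extended-context logic is typically driven. extend_patch() pads each hunk with
# extra context lines pulled from the original file and rewrites the hunk header.
# The file text, patch text and extra-line counts below are made-up placeholders,
# and the call assumes the project's default settings are importable.
from utils.pr_agent.algo.git_patch_processing import extend_patch

original_file = "line 1\nline 2\nline 3\nline 4\nline 5\n"
small_patch = "@@ -3,1 +3,1 @@\n-line 3\n+line three\n"
extended = extend_patch(original_file, small_patch, 1, 1, "example.txt")
# 'extended' should now begin with an adjusted '@@ -2,3 +2,3 @@'-style header
# followed by the surrounding context lines taken from original_file.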
@@ -158,11 +249,14 @@ def check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1):
    """
    is_valid_hunk = True
    try:
        if (
            i + 1 < len(patch_lines) and patch_lines[i + 1][0] == ' '
        ):  # an existing line in the file
            if patch_lines[i + 1].strip() != original_lines[start1 - 1].strip():
                is_valid_hunk = False
                get_logger().error(
                    f"Invalid hunk in PR, line {start1} in hunk header doesn't match the original file content"
                )
    except:
        pass
    return is_valid_hunk
@@ -195,8 +289,7 @@ def omit_deletion_hunks(patch_lines) -> str:
    added_patched = []
    add_hunk = False
    inside_hunk = False
    RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))?\ @@[ ]?(.*)")

    for line in patch_lines:
        if line.startswith('@@'):
@@ -221,8 +314,13 @@ def omit_deletion_hunks(patch_lines) -> str:
    return '\n'.join(added_patched)


def handle_patch_deletions(
    patch: str,
    original_file_content_str: str,
    new_file_content_str: str,
    file_name: str,
    edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN,
) -> str:
    """
    Handle entire file or deletion patches.

@@ -239,11 +337,13 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
        str: The modified patch with deletion hunks omitted.

    """
    if not new_file_content_str and (
        edit_type == EDIT_TYPE.DELETED or edit_type == EDIT_TYPE.UNKNOWN
    ):
        # logic for handling deleted files - don't show patch, just show that the file was deleted
        if get_settings().config.verbosity_level > 0:
            get_logger().info(f"Processing file: {file_name}, minimizing deletion file")
        patch = None  # file was deleted
    else:
        patch_lines = patch.splitlines()
        patch_new = omit_deletion_hunks(patch_lines)
@@ -256,35 +356,35 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,

def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
    """
    Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of
    the file.

    Args:
        patch (str): The patch string to be converted.
        file: An object containing the filename of the file being patched.

    Returns:
        str: A string with line numbers for each hunk, indicating the new and old content of the file.

    example output:
    ## src/file.ts
    __new hunk__
    881 line1
    882 line2
    883 line3
    887 + line4
    888 + line5
    889 line6
    890 line7
    ...
    __old hunk__
    line1
    line2
    - line3
    - line4
    line5
    line6
    ...
    """
    # if the file was deleted, return a message indicating that the file was deleted
    if hasattr(file, 'edit_type') and file.edit_type == EDIT_TYPE.DELETED:
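# Illustrative usage sketch (annotation, not part of the diff above): the docstring
# example is produced by a call of roughly this shape; 'patch_text' and 'file_obj'
# are hypothetical stand-ins for a real patch string and the diff-file object that
# carries .filename and .edit_type.
# numbered = convert_to_hunks_with_lines_numbers(patch_text, file_obj)
# The result interleaves a numbered "__new hunk__" block with an unnumbered
# "__old hunk__" block per hunk, as sketched in the docstring above.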
@@ -292,8 +392,7 @@ __old hunk__

    patch_with_lines_str = f"\n\n## File: '{file.filename.strip()}'\n"
    patch_lines = patch.splitlines()
    RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
    new_content_lines = []
    old_content_lines = []
    match = None
@@ -307,20 +406,32 @@ __old hunk__
        if line.startswith('@@'):
            header_line = line
            match = RE_HUNK_HEADER.match(line)
            if match and (
                new_content_lines or old_content_lines
            ):  # found a new hunk, split the previous lines
                if prev_header_line:
                    patch_with_lines_str += f'\n{prev_header_line}\n'
                is_plus_lines = is_minus_lines = False
                if new_content_lines:
                    is_plus_lines = any(
                        [line.startswith('+') for line in new_content_lines]
                    )
                if old_content_lines:
                    is_minus_lines = any(
                        [line.startswith('-') for line in old_content_lines]
                    )
                if (
                    is_plus_lines or is_minus_lines
                ):  # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
                    patch_with_lines_str = (
                        patch_with_lines_str.rstrip() + '\n__new hunk__\n'
                    )
                    for i, line_new in enumerate(new_content_lines):
                        patch_with_lines_str += f"{start2 + i} {line_new}\n"
                    if is_minus_lines:
                        patch_with_lines_str = (
                            patch_with_lines_str.rstrip() + '\n__old hunk__\n'
                        )
                        for line_old in old_content_lines:
                            patch_with_lines_str += f"{line_old}\n"
                new_content_lines = []
@@ -335,8 +446,12 @@ __old hunk__
        elif line.startswith('-'):
            old_content_lines.append(line)
        else:
            if (
                not line and line_i
            ):  # if this line is empty and the next line is a hunk header, skip it
                if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith(
                    '@@'
                ):
                    continue
                elif line_i + 1 == len(patch_lines):
                    continue
@@ -351,7 +466,9 @@ __old hunk__
        is_plus_lines = any([line.startswith('+') for line in new_content_lines])
    if old_content_lines:
        is_minus_lines = any([line.startswith('-') for line in old_content_lines])
    if (
        is_plus_lines or is_minus_lines
    ):  # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused
        patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__new hunk__\n'
        for i, line_new in enumerate(new_content_lines):
            patch_with_lines_str += f"{start2 + i} {line_new}\n"
@@ -363,13 +480,16 @@ __old hunk__
    return patch_with_lines_str.rstrip()


def extract_hunk_lines_from_patch(
    patch: str, file_name, line_start, line_end, side
) -> tuple[str, str]:
    try:
        patch_with_lines_str = f"\n\n## File: '{file_name.strip()}'\n\n"
        selected_lines = ""
        patch_lines = patch.splitlines()
        RE_HUNK_HEADER = re.compile(
            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)"
        )
        match = None
        start1, size1, start2, size2 = -1, -1, -1, -1
        skip_hunk = False
@@ -385,7 +505,9 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s

                match = RE_HUNK_HEADER.match(line)

                section_header, size1, size2, start1, start2 = extract_hunk_headers(
                    match
                )

                # check if line range is in this hunk
                if side.lower() == 'left':
@@ -400,15 +522,26 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s
                    patch_with_lines_str += f'\n{header_line}\n'

            elif not skip_hunk:
                if (
                    side.lower() == 'right'
                    and line_start <= start2 + selected_lines_num <= line_end
                ):
                    selected_lines += line + '\n'
                if (
                    side.lower() == 'left'
                    and start1 <= selected_lines_num + start1 <= line_end
                ):
                    selected_lines += line + '\n'
                patch_with_lines_str += line + '\n'
                if not line.startswith(
                    '-'
                ):  # currently we don't support /ask line for deleted lines
                    selected_lines_num += 1
    except Exception as e:
        get_logger().error(
            f"Failed to extract hunk lines from patch: {e}",
            artifact={"traceback": traceback.format_exc()},
        )
        return "", ""

    return patch_with_lines_str.rstrip(), selected_lines.rstrip()
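# Illustrative usage sketch (annotation, not part of the diff above): the parameter
# names come straight from the signature; the patch string and line range are
# invented, and the import path assumes the function lives next to the ones above.
from utils.pr_agent.algo.git_patch_processing import extract_hunk_lines_from_patch

demo_patch = "@@ -10,2 +10,3 @@ def foo():\n     a = 1\n+    b = 2\n     return a\n"
hunk_str, selected = extract_hunk_lines_from_patch(
    demo_patch, "example.py", line_start=11, line_end=11, side="right"
)
# 'hunk_str' repeats the matching hunk under a "## File: 'example.py'" banner, while
# 'selected' keeps only the lines whose new-file numbers fall in [line_start, line_end].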
@@ -9,10 +9,14 @@ def filter_bad_extensions(files):
    bad_extensions = get_settings().bad_extensions.default
    if get_settings().config.use_extra_bad_extensions:
        bad_extensions += get_settings().bad_extensions.extra
    return [
        f
        for f in files
        if f.filename is not None and is_valid_file(f.filename, bad_extensions)
    ]


def is_valid_file(filename: str, bad_extensions=None) -> bool:
    if not filename:
        return False
    if not bad_extensions:
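# Illustrative behaviour sketch (annotation, not part of the diff above): the filter
# keeps only files whose extension is not on the configured bad-extensions list.
# The extension values below are assumptions, not the real defaults.
# is_valid_file("src/app.py")      -> True, assuming 'py' is not listed as bad
# is_valid_file("assets/logo.png") -> False, assuming 'png' is listed as bad
# filter_bad_extensions(files)     -> the subset of 'files' that passes that check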
@@ -27,12 +31,16 @@ def sort_files_by_main_languages(languages: Dict, files: list):
    Sort files by their main language, put the files that are in the main language first and the rest files after
    """
    # sort languages by their size
    languages_sorted_list = [
        k for k, v in sorted(languages.items(), key=lambda item: item[1], reverse=True)
    ]
    # languages_sorted = sorted(languages, key=lambda x: x[1], reverse=True)
    # get all extensions for the languages
    main_extensions = []
    language_extension_map_org = get_settings().language_extension_map_org
    language_extension_map = {
        k.lower(): v for k, v in language_extension_map_org.items()
    }
    for language in languages_sorted_list:
        if language.lower() in language_extension_map:
            main_extensions.append(language_extension_map[language.lower()])
@@ -62,7 +70,9 @@ def sort_files_by_main_languages(languages: Dict, files: list):
            if extension_str in extensions:
                tmp.append(file)
            else:
                if (file.filename not in rest_files) and (
                    extension_str not in main_extensions_flat
                ):
                    rest_files[file.filename] = file
        if len(tmp) > 0:
            files_sorted.append({"language": lang, "files": tmp})
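# Illustrative output shape (annotation, not part of the diff above): the function
# returns one group per detected language, largest language first, with files that
# match no main language collected in a final catch-all group. Values are invented
# and the catch-all group name is an assumption.
# sort_files_by_main_languages({"Python": 9000, "Shell": 120}, files) ->
# [
#     {"language": "Python", "files": [...]},
#     {"language": "Shell", "files": [...]},
#     {"language": "Other", "files": [...]},
# ]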
@@ -7,18 +7,28 @@ from github import RateLimitExceededException

from utils.pr_agent.algo.file_filter import filter_ignored
from utils.pr_agent.algo.git_patch_processing import (
    convert_to_hunks_with_lines_numbers,
    extend_patch,
    handle_patch_deletions,
)
from utils.pr_agent.algo.language_handler import sort_files_by_main_languages
from utils.pr_agent.algo.token_handler import TokenHandler
from utils.pr_agent.algo.types import EDIT_TYPE
from utils.pr_agent.algo.utils import (
    ModelType,
    clip_tokens,
    get_max_tokens,
    get_weak_model,
)
from utils.pr_agent.config_loader import get_settings
from utils.pr_agent.git_providers.git_provider import GitProvider
from utils.pr_agent.log import get_logger

DELETED_FILES_ = "Deleted files:\n"

MORE_MODIFIED_FILES_ = (
    "Additional modified files (insufficient token budget to process):\n"
)

ADDED_FILES_ = "Additional added files (insufficient token budget to process):\n"

@@ -29,45 +39,59 @@ MAX_EXTRA_LINES = 10

def cap_and_log_extra_lines(value, direction) -> int:
    if value > MAX_EXTRA_LINES:
        get_logger().warning(
            f"patch_extra_lines_{direction} was {value}, capping to {MAX_EXTRA_LINES}"
        )
        return MAX_EXTRA_LINES
    return value


def get_pr_diff(
    git_provider: GitProvider,
    token_handler: TokenHandler,
    model: str,
    add_line_numbers_to_hunks: bool = False,
    disable_extra_lines: bool = False,
    large_pr_handling=False,
    return_remaining_files=False,
):
    if disable_extra_lines:
        PATCH_EXTRA_LINES_BEFORE = 0
        PATCH_EXTRA_LINES_AFTER = 0
    else:
        PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
        PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
        PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(
            PATCH_EXTRA_LINES_BEFORE, "before"
        )
        PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(
            PATCH_EXTRA_LINES_AFTER, "after"
        )

    try:
        diff_files_original = git_provider.get_diff_files()
    except RateLimitExceededException as e:
        get_logger().error(
            f"Rate limit exceeded for git provider API. original message {e}"
        )
        raise

    diff_files = filter_ignored(diff_files_original)
    if diff_files != diff_files_original:
        try:
            get_logger().info(
                f"Filtered out {len(diff_files_original) - len(diff_files)} files"
            )
            new_names = set([a.filename for a in diff_files])
            orig_names = set([a.filename for a in diff_files_original])
            get_logger().info(f"Filtered out files: {orig_names - new_names}")
        except Exception as e:
            pass

    # get pr languages
    pr_languages = sort_files_by_main_languages(
        git_provider.get_languages(), diff_files
    )
    if pr_languages:
        try:
            get_logger().info(f"PR main language: {pr_languages[0]['language']}")
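# Illustrative behaviour sketch (annotation, not part of the diff above) for
# cap_and_log_extra_lines() defined in the hunk above: values within the cap pass
# through unchanged, larger values are clamped to MAX_EXTRA_LINES (10) and a
# warning is logged.
# cap_and_log_extra_lines(3, "before")  -> 3
# cap_and_log_extra_lines(50, "after")  -> 10, plus a "capping to 10" warning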
@@ -76,24 +100,42 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,

    # generate a standard diff string, with patch extension
    patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
        pr_languages,
        token_handler,
        add_line_numbers_to_hunks,
        patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE,
        patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER,
    )

    # if we are under the limit, return the full diff
    if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
        get_logger().info(
            f"Tokens: {total_tokens}, total tokens under limit: {get_max_tokens(model)}, "
            f"returning full diff."
        )
        return "\n".join(patches_extended)

    # if we are over the limit, start pruning (If we got here, we will not extend the patches with extra lines)
    get_logger().info(
        f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
        f"pruning diff."
    )
    (
        patches_compressed_list,
        total_tokens_list,
        deleted_files_list,
        remaining_files_list,
        file_dict,
        files_in_patches_list,
    ) = pr_generate_compressed_diff(
        pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling
    )

    if large_pr_handling and len(patches_compressed_list) > 1:
        get_logger().info(
            f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff."
        )
        return ""  # return empty string, as we want to generate multiple patches with a different prompt

    # return the first patch
    patches_compressed = patches_compressed_list[0]
@@ -144,26 +186,37 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
    if deleted_list_str:
        final_diff = final_diff + "\n\n" + deleted_list_str

    get_logger().debug(
        f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
        f"deleted_list_str: {deleted_list_str}"
    )
    if not return_remaining_files:
        return final_diff
    else:
        return final_diff, remaining_files_list


def get_pr_diff_multiple_patchs(
    git_provider: GitProvider,
    token_handler: TokenHandler,
    model: str,
    add_line_numbers_to_hunks: bool = False,
    disable_extra_lines: bool = False,
):
    try:
        diff_files_original = git_provider.get_diff_files()
    except RateLimitExceededException as e:
        get_logger().error(
            f"Rate limit exceeded for git provider API. original message {e}"
        )
        raise

    diff_files = filter_ignored(diff_files_original)
    if diff_files != diff_files_original:
        try:
            get_logger().info(
                f"Filtered out {len(diff_files_original) - len(diff_files)} files"
            )
            new_names = set([a.filename for a in diff_files])
            orig_names = set([a.filename for a in diff_files_original])
            get_logger().info(f"Filtered out files: {orig_names - new_names}")
@@ -171,24 +224,47 @@ def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenH
            pass

    # get pr languages
    pr_languages = sort_files_by_main_languages(
        git_provider.get_languages(), diff_files
    )
    if pr_languages:
        try:
            get_logger().info(f"PR main language: {pr_languages[0]['language']}")
        except Exception as e:
            pass

    (
        patches_compressed_list,
        total_tokens_list,
        deleted_files_list,
        remaining_files_list,
        file_dict,
        files_in_patches_list,
    ) = pr_generate_compressed_diff(
        pr_languages,
        token_handler,
        model,
        add_line_numbers_to_hunks,
        large_pr_handling=True,
    )

    return (
        patches_compressed_list,
        total_tokens_list,
        deleted_files_list,
        remaining_files_list,
        file_dict,
        files_in_patches_list,
    )


def pr_generate_extended_diff(
    pr_languages: list,
    token_handler: TokenHandler,
    add_line_numbers_to_hunks: bool,
    patch_extra_lines_before: int = 0,
    patch_extra_lines_after: int = 0,
) -> Tuple[list, int, list]:
    total_tokens = token_handler.prompt_tokens  # initial tokens
    patches_extended = []
    patches_extended_tokens = []
@@ -200,20 +276,33 @@ def pr_generate_extended_diff(pr_languages: list,
            continue

        # extend each patch with extra lines of context
        extended_patch = extend_patch(
            original_file_content_str,
            patch,
            patch_extra_lines_before,
            patch_extra_lines_after,
            file.filename,
        )
        if not extended_patch:
            get_logger().warning(
                f"Failed to extend patch for file: {file.filename}"
            )
            continue

        if add_line_numbers_to_hunks:
            full_extended_patch = convert_to_hunks_with_lines_numbers(
                extended_patch, file
            )
        else:
            full_extended_patch = f"\n\n## File: '{file.filename.strip()}'\n{extended_patch.rstrip()}\n"

        # add AI-summary metadata to the patch
        if file.ai_file_summary and get_settings().get(
            "config.enable_ai_metadata", False
        ):
            full_extended_patch = add_ai_summary_top_patch(
                file, full_extended_patch
            )

        patch_tokens = token_handler.count_tokens(full_extended_patch)
        file.tokens = patch_tokens
@@ -224,9 +313,13 @@ def pr_generate_extended_diff(pr_languages: list,
    return patches_extended, total_tokens, patches_extended_tokens


def pr_generate_compressed_diff(
    top_langs: list,
    token_handler: TokenHandler,
    model: str,
    convert_hunks_to_line_numbers: bool,
    large_pr_handling: bool,
) -> Tuple[list, list, list, list, dict, list]:
    deleted_files_list = []

    # sort each one of the languages in top_langs by the number of tokens in the diff
@@ -244,8 +337,13 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
            continue

        # removing delete-only hunks
        patch = handle_patch_deletions(
            patch,
            original_file_content_str,
            new_file_content_str,
            file.filename,
            file.edit_type,
        )
        if patch is None:
            if file.filename not in deleted_files_list:
                deleted_files_list.append(file.filename)
@@ -259,30 +357,54 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
        # patch = add_ai_summary_top_patch(file, patch)

        new_patch_tokens = token_handler.count_tokens(patch)
        file_dict[file.filename] = {
            'patch': patch,
            'tokens': new_patch_tokens,
            'edit_type': file.edit_type,
        }

    max_tokens_model = get_max_tokens(model)

    # first iteration
    files_in_patches_list = []
    remaining_files_list = [file.filename for file in sorted_files]
    patches_list = []
    total_tokens_list = []
    (
        total_tokens,
        patches,
        remaining_files_list,
        files_in_patch_list,
    ) = generate_full_patch(
        convert_hunks_to_line_numbers,
        file_dict,
        max_tokens_model,
        remaining_files_list,
        token_handler,
    )
    patches_list.append(patches)
    total_tokens_list.append(total_tokens)
    files_in_patches_list.append(files_in_patch_list)

    # additional iterations (if needed)
    if large_pr_handling:
        NUMBER_OF_ALLOWED_ITERATIONS = (
            get_settings().pr_description.max_ai_calls - 1
        )  # one more call is to summarize
        for i in range(NUMBER_OF_ALLOWED_ITERATIONS - 1):
            if remaining_files_list:
                (
                    total_tokens,
                    patches,
                    remaining_files_list,
                    files_in_patch_list,
                ) = generate_full_patch(
                    convert_hunks_to_line_numbers,
                    file_dict,
                    max_tokens_model,
                    remaining_files_list,
                    token_handler,
                )
                if patches:
                    patches_list.append(patches)
                    total_tokens_list.append(total_tokens)
@@ -290,11 +412,24 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
            else:
                break

    return (
        patches_list,
        total_tokens_list,
        deleted_files_list,
        remaining_files_list,
        file_dict,
        files_in_patches_list,
    )


def generate_full_patch(
    convert_hunks_to_line_numbers,
    file_dict,
    max_tokens_model,
    remaining_files_list_prev,
    token_handler,
):
    total_tokens = token_handler.prompt_tokens  # initial tokens
    patches = []
    remaining_files_list_new = []
    files_in_patch_list = []
@@ -312,7 +447,10 @@ def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_mod
            continue

        # If the patch is too large, just show the file name
        if (
            total_tokens + new_patch_tokens
            > max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
        ):
            # Current logic is to skip the patch if it's too large
            # TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
            # until we meet the requirements
@@ -334,7 +472,9 @@ def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_mod
    return total_tokens, patches, remaining_files_list_new, files_in_patch_list


async def retry_with_fallback_models(
    f: Callable, model_type: ModelType = ModelType.REGULAR
):
    all_models = _get_all_models(model_type)
    all_deployments = _get_all_deployments(all_models)
    # try each (model, deployment_id) pair until one is successful, otherwise raise exception
@@ -347,11 +487,11 @@ async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelT
            get_settings().set("openai.deployment_id", deployment_id)
            return await f(model)
        except:
            get_logger().warning(f"Failed to generate prediction with {model}")
            if i == len(all_models) - 1:  # If it's the last iteration
                raise Exception(
                    f"Failed to generate prediction with any model of {all_models}"
                )


def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
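# Illustrative usage sketch (annotation, not part of the diff above):
# retry_with_fallback_models() walks the configured (model, deployment_id) pairs in
# order and only raises after the last pair fails. The coroutine below is a made-up
# stand-in for a real prediction call; running the commented line also assumes the
# project settings are configured.
import asyncio

async def _demo_predict(model: str) -> str:
    return f"prediction from {model}"  # placeholder instead of a real LLM call

# result = asyncio.run(retry_with_fallback_models(_demo_predict, model_type=ModelType.REGULAR))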
@ -374,17 +514,21 @@ def _get_all_deployments(all_models: List[str]) -> List[str]:
|
|||||||
if fallback_deployments:
|
if fallback_deployments:
|
||||||
all_deployments = [deployment_id] + fallback_deployments
|
all_deployments = [deployment_id] + fallback_deployments
|
||||||
if len(all_deployments) < len(all_models):
|
if len(all_deployments) < len(all_models):
|
||||||
raise ValueError(f"The number of deployments ({len(all_deployments)}) "
|
raise ValueError(
|
||||||
f"is less than the number of models ({len(all_models)})")
|
f"The number of deployments ({len(all_deployments)}) "
|
||||||
|
f"is less than the number of models ({len(all_models)})"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
all_deployments = [deployment_id] * len(all_models)
|
all_deployments = [deployment_id] * len(all_models)
|
||||||
return all_deployments
|
return all_deployments
|
||||||
|
|
||||||
|
|
||||||
def get_pr_multi_diffs(git_provider: GitProvider,
|
def get_pr_multi_diffs(
|
||||||
token_handler: TokenHandler,
|
git_provider: GitProvider,
|
||||||
model: str,
|
token_handler: TokenHandler,
|
||||||
max_calls: int = 5) -> List[str]:
|
model: str,
|
||||||
|
max_calls: int = 5,
|
||||||
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Retrieves the diff files from a Git provider, sorts them by main language, and generates patches for each file.
|
Retrieves the diff files from a Git provider, sorts them by main language, and generates patches for each file.
|
||||||
The patches are split into multiple groups based on the maximum number of tokens allowed for the given model.
|
The patches are split into multiple groups based on the maximum number of tokens allowed for the given model.
|
||||||
@ -404,13 +548,17 @@ def get_pr_multi_diffs(git_provider: GitProvider,
|
|||||||
try:
|
try:
|
||||||
diff_files = git_provider.get_diff_files()
|
diff_files = git_provider.get_diff_files()
|
||||||
except RateLimitExceededException as e:
|
except RateLimitExceededException as e:
|
||||||
get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
|
get_logger().error(
|
||||||
|
f"Rate limit exceeded for git provider API. original message {e}"
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
diff_files = filter_ignored(diff_files)
|
diff_files = filter_ignored(diff_files)
|
||||||
|
|
||||||
# Sort files by main language
|
# Sort files by main language
|
||||||
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
|
pr_languages = sort_files_by_main_languages(
|
||||||
|
git_provider.get_languages(), diff_files
|
||||||
|
)
|
||||||
|
|
||||||
# Sort files within each language group by tokens in descending order
|
# Sort files within each language group by tokens in descending order
|
||||||
sorted_files = []
|
sorted_files = []
|
||||||
@ -420,14 +568,19 @@ def get_pr_multi_diffs(git_provider: GitProvider,
|
|||||||
# Get the maximum number of extra lines before and after the patch
|
# Get the maximum number of extra lines before and after the patch
|
||||||
PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
|
PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
|
||||||
PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
|
PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
|
||||||
PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(PATCH_EXTRA_LINES_BEFORE, "before")
|
PATCH_EXTRA_LINES_BEFORE = cap_and_log_extra_lines(
|
||||||
|
PATCH_EXTRA_LINES_BEFORE, "before"
|
||||||
|
)
|
||||||
PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(PATCH_EXTRA_LINES_AFTER, "after")
|
PATCH_EXTRA_LINES_AFTER = cap_and_log_extra_lines(PATCH_EXTRA_LINES_AFTER, "after")
|
||||||
|
|
||||||
# try first a single run with standard diff string, with patch extension, and no deletions
|
# try first a single run with standard diff string, with patch extension, and no deletions
|
||||||
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
|
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
|
||||||
pr_languages, token_handler, add_line_numbers_to_hunks=True,
|
pr_languages,
|
||||||
|
token_handler,
|
||||||
|
add_line_numbers_to_hunks=True,
|
||||||
patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE,
|
patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE,
|
||||||
patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)
|
patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER,
|
||||||
|
)
|
||||||
|
|
||||||
# if we are under the limit, return the full diff
|
# if we are under the limit, return the full diff
|
||||||
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
|
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
|
||||||
@ -450,27 +603,50 @@ def get_pr_multi_diffs(git_provider: GitProvider,
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Remove delete-only hunks
|
# Remove delete-only hunks
|
||||||
patch = handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file.filename, file.edit_type)
|
patch = handle_patch_deletions(
|
||||||
|
patch,
|
||||||
|
original_file_content_str,
|
||||||
|
new_file_content_str,
|
||||||
|
file.filename,
|
||||||
|
file.edit_type,
|
||||||
|
)
|
||||||
if patch is None:
|
if patch is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
patch = convert_to_hunks_with_lines_numbers(patch, file)
|
patch = convert_to_hunks_with_lines_numbers(patch, file)
|
||||||
# add AI-summary metadata to the patch
|
# add AI-summary metadata to the patch
|
||||||
if file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False):
|
if file.ai_file_summary and get_settings().get(
|
||||||
|
"config.enable_ai_metadata", False
|
||||||
|
):
|
||||||
patch = add_ai_summary_top_patch(file, patch)
|
patch = add_ai_summary_top_patch(file, patch)
|
||||||
new_patch_tokens = token_handler.count_tokens(patch)
|
new_patch_tokens = token_handler.count_tokens(patch)
|
||||||
|
|
||||||
if patch and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
|
if (
|
||||||
model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
patch
|
||||||
|
and (token_handler.prompt_tokens + new_patch_tokens)
|
||||||
|
> get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
|
||||||
|
):
|
||||||
if get_settings().config.get('large_patch_policy', 'skip') == 'skip':
|
if get_settings().config.get('large_patch_policy', 'skip') == 'skip':
|
||||||
get_logger().warning(f"Patch too large, skipping: {file.filename}")
|
get_logger().warning(f"Patch too large, skipping: {file.filename}")
|
||||||
continue
|
continue
|
||||||
elif get_settings().config.get('large_patch_policy') == 'clip':
|
elif get_settings().config.get('large_patch_policy') == 'clip':
|
||||||
delta_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens
|
delta_tokens = (
|
||||||
patch_clipped = clip_tokens(patch, delta_tokens, delete_last_line=True, num_input_tokens=new_patch_tokens)
|
get_max_tokens(model)
|
||||||
|
- OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
|
||||||
|
- token_handler.prompt_tokens
|
||||||
|
)
|
||||||
|
patch_clipped = clip_tokens(
|
||||||
|
patch,
|
||||||
|
delta_tokens,
|
||||||
|
delete_last_line=True,
|
||||||
|
num_input_tokens=new_patch_tokens,
|
||||||
|
)
|
||||||
new_patch_tokens = token_handler.count_tokens(patch_clipped)
|
new_patch_tokens = token_handler.count_tokens(patch_clipped)
|
||||||
if patch_clipped and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
|
if (
|
||||||
model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
patch_clipped
|
||||||
|
and (token_handler.prompt_tokens + new_patch_tokens)
|
||||||
|
> get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
|
||||||
|
):
|
||||||
get_logger().warning(f"Patch too large, skipping: {file.filename}")
|
get_logger().warning(f"Patch too large, skipping: {file.filename}")
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@ -480,13 +656,16 @@ def get_pr_multi_diffs(git_provider: GitProvider,
|
|||||||
get_logger().warning(f"Patch too large, skipping: {file.filename}")
|
get_logger().warning(f"Patch too large, skipping: {file.filename}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if patch and (total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
|
if patch and (
|
||||||
|
total_tokens + new_patch_tokens
|
||||||
|
> get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD
|
||||||
|
):
|
||||||
final_diff = "\n".join(patches)
|
final_diff = "\n".join(patches)
|
||||||
final_diff_list.append(final_diff)
|
final_diff_list.append(final_diff)
|
||||||
patches = []
|
patches = []
|
||||||
total_tokens = token_handler.prompt_tokens
|
total_tokens = token_handler.prompt_tokens
|
||||||
call_number += 1
|
call_number += 1
|
||||||
if call_number > max_calls: # avoid creating new patches
|
if call_number > max_calls: # avoid creating new patches
|
||||||
if get_settings().config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
get_logger().info(f"Reached max calls ({max_calls})")
|
get_logger().info(f"Reached max calls ({max_calls})")
|
||||||
break
|
break
|
||||||
@ -497,7 +676,9 @@ def get_pr_multi_diffs(git_provider: GitProvider,
|
|||||||
patches.append(patch)
|
patches.append(patch)
|
||||||
total_tokens += new_patch_tokens
|
total_tokens += new_patch_tokens
|
||||||
if get_settings().config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
get_logger().info(f"Tokens: {total_tokens}, last filename: {file.filename}")
|
get_logger().info(
|
||||||
|
f"Tokens: {total_tokens}, last filename: {file.filename}"
|
||||||
|
)
|
||||||
|
|
||||||
# Add the last chunk
|
# Add the last chunk
|
||||||
if patches:
|
if patches:
|
||||||
@ -515,7 +696,10 @@ def add_ai_metadata_to_diff_files(git_provider, pr_description_files):
|
|||||||
if not pr_description_files:
|
if not pr_description_files:
|
||||||
get_logger().warning(f"PR description files are empty.")
|
get_logger().warning(f"PR description files are empty.")
|
||||||
return
|
return
|
||||||
available_files = {pr_file['full_file_name'].strip(): pr_file for pr_file in pr_description_files}
|
available_files = {
|
||||||
|
pr_file['full_file_name'].strip(): pr_file
|
||||||
|
for pr_file in pr_description_files
|
||||||
|
}
|
||||||
diff_files = git_provider.get_diff_files()
|
diff_files = git_provider.get_diff_files()
|
||||||
found_any_match = False
|
found_any_match = False
|
||||||
for file in diff_files:
|
for file in diff_files:
|
||||||
@ -524,11 +708,15 @@ def add_ai_metadata_to_diff_files(git_provider, pr_description_files):
|
|||||||
file.ai_file_summary = available_files[filename]
|
file.ai_file_summary = available_files[filename]
|
||||||
found_any_match = True
|
found_any_match = True
|
||||||
if not found_any_match:
|
if not found_any_match:
|
||||||
get_logger().error(f"Failed to find any matching files between PR description and diff files.",
|
get_logger().error(
|
||||||
artifact={"pr_description_files": pr_description_files})
|
f"Failed to find any matching files between PR description and diff files.",
|
||||||
|
artifact={"pr_description_files": pr_description_files},
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Failed to add AI metadata to diff files: {e}",
|
get_logger().error(
|
||||||
artifact={"traceback": traceback.format_exc()})
|
f"Failed to add AI metadata to diff files: {e}",
|
||||||
|
artifact={"traceback": traceback.format_exc()},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_ai_summary_top_patch(file, full_extended_patch):
|
def add_ai_summary_top_patch(file, full_extended_patch):
|
||||||
@ -537,14 +725,18 @@ def add_ai_summary_top_patch(file, full_extended_patch):
|
|||||||
full_extended_patch_lines = full_extended_patch.split("\n")
|
full_extended_patch_lines = full_extended_patch.split("\n")
|
||||||
for i, line in enumerate(full_extended_patch_lines):
|
for i, line in enumerate(full_extended_patch_lines):
|
||||||
if line.startswith("## File:") or line.startswith("## file:"):
|
if line.startswith("## File:") or line.startswith("## file:"):
|
||||||
full_extended_patch_lines.insert(i + 1,
|
full_extended_patch_lines.insert(
|
||||||
f"### AI-generated changes summary:\n{file.ai_file_summary['long_summary']}")
|
i + 1,
|
||||||
|
f"### AI-generated changes summary:\n{file.ai_file_summary['long_summary']}",
|
||||||
|
)
|
||||||
full_extended_patch = "\n".join(full_extended_patch_lines)
|
full_extended_patch = "\n".join(full_extended_patch_lines)
|
||||||
return full_extended_patch
|
return full_extended_patch
|
||||||
|
|
||||||
# if no '## File: ...' was found
|
# if no '## File: ...' was found
|
||||||
return full_extended_patch
|
return full_extended_patch
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Failed to add AI summary to the top of the patch: {e}",
|
get_logger().error(
|
||||||
artifact={"traceback": traceback.format_exc()})
|
f"Failed to add AI summary to the top of the patch: {e}",
|
||||||
|
artifact={"traceback": traceback.format_exc()},
|
||||||
|
)
|
||||||
return full_extended_patch
|
return full_extended_patch
|
||||||
|
|||||||
@ -15,12 +15,17 @@ class TokenEncoder:
    @classmethod
    def get_token_encoder(cls):
        model = get_settings().config.model
        if (
            cls._encoder_instance is None or model != cls._model
        ):  # Check without acquiring the lock for performance
            with cls._lock:  # Lock acquisition to ensure thread safety
                if cls._encoder_instance is None or model != cls._model:
                    cls._model = model
                    cls._encoder_instance = (
                        encoding_for_model(cls._model)
                        if "gpt" in cls._model
                        else get_encoding("cl100k_base")
                    )
        return cls._encoder_instance

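Aside (not part of the commit): the hunk above keeps a double-checked locking pattern — the cached encoder is checked once without the lock for speed, then re-checked under the lock before being rebuilt. A minimal standalone sketch of the same idea, with hypothetical names standing in for the real encoder factory, could look like this:

import threading

class EncoderCache:
    """Lazily builds one encoder per model name; safe to call from many threads."""
    _lock = threading.Lock()
    _model = None
    _encoder = None

    @classmethod
    def get(cls, model: str):
        # Fast path: no lock taken when the cached encoder is already valid.
        if cls._encoder is None or model != cls._model:
            with cls._lock:
                # Re-check inside the lock: another thread may have rebuilt it
                # between our first check and acquiring the lock.
                if cls._encoder is None or model != cls._model:
                    cls._model = model
                    cls._encoder = object()  # stand-in for the real encoder factory
        return cls._encoder
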
@ -49,7 +54,9 @@ class TokenHandler:
        """
        self.encoder = TokenEncoder.get_token_encoder()
        if pr is not None:
            self.prompt_tokens = self._get_system_user_tokens(
                pr, self.encoder, vars, system, user
            )

    def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
        """
@ -41,10 +41,12 @@ class Range(BaseModel):
    column_start: int = -1
    column_end: int = -1


class ModelType(str, Enum):
    REGULAR = "regular"
    WEAK = "weak"


class PRReviewHeader(str, Enum):
    REGULAR = "## PR 评审指南"
    INCREMENTAL = "## 增量 PR 评审指南"
@ -57,7 +59,9 @@ class PRDescriptionHeader(str, Enum):
def get_setting(key: str) -> Any:
    try:
        key = key.upper()
        return context.get("settings", global_settings).get(
            key, global_settings.get(key, None)
        )
    except Exception:
        return global_settings.get(key, None)

@ -72,14 +76,29 @@ def emphasize_header(text: str, only_markdown=False, reference_link=None) -> str
            # Everything before the colon (inclusive) is wrapped in <strong> tags
            if only_markdown:
                if reference_link:
                    transformed_string = (
                        f"[**{text[:colon_position + 1]}**]({reference_link})\n"
                        + text[colon_position + 1 :]
                    )
                else:
                    transformed_string = (
                        f"**{text[:colon_position + 1]}**\n"
                        + text[colon_position + 1 :]
                    )
            else:
                if reference_link:
                    transformed_string = (
                        f"<strong><a href='{reference_link}'>{text[:colon_position + 1]}</a></strong><br>"
                        + text[colon_position + 1 :]
                    )
                else:
                    transformed_string = (
                        "<strong>"
                        + text[: colon_position + 1]
                        + "</strong>"
                        + '<br>'
                        + text[colon_position + 1 :]
                    )
        else:
            # If there's no ": ", return the original string
            transformed_string = text
@ -101,11 +120,14 @@ def unique_strings(input_list: List[str]) -> List[str]:
        seen.add(item)
    return unique_list


def convert_to_markdown_v2(
    output_data: dict,
    gfm_supported: bool = True,
    incremental_review=None,
    git_provider=None,
    files=None,
) -> str:
    """
    Convert a dictionary of data into markdown format.
    Args:
@ -183,7 +205,9 @@ def convert_to_markdown_v2(output_data: dict,
                else:
                    markdown_text += f"### {emoji} PR 包含测试\n\n"
        elif 'ticket compliance check' in key_nice.lower():
            markdown_text = ticket_markdown_logic(
                emoji, markdown_text, value, gfm_supported
            )
        elif 'security concerns' in key_nice.lower():
            if gfm_supported:
                markdown_text += f"<tr><td>"
@ -220,7 +244,9 @@ def convert_to_markdown_v2(output_data: dict,
                if gfm_supported:
                    markdown_text += f"<tr><td>"
                    # markdown_text += f"{emoji} <strong>{key_nice}</strong><br><br>\n\n"
                    markdown_text += (
                        f"{emoji} <strong>建议评审的重点领域</strong><br><br>\n\n"
                    )
                else:
                    markdown_text += f"### {emoji} 建议评审的重点领域\n\n#### \n"
                for i, issue in enumerate(issues):
@ -235,9 +261,13 @@ def convert_to_markdown_v2(output_data: dict,
                        start_line = int(str(issue.get('start_line', 0)).strip())
                        end_line = int(str(issue.get('end_line', 0)).strip())

                        relevant_lines_str = extract_relevant_lines_str(
                            end_line, files, relevant_file, start_line, dedent=True
                        )
                        if git_provider:
                            reference_link = git_provider.get_line_link(
                                relevant_file, start_line, end_line
                            )
                        else:
                            reference_link = None

@ -256,7 +286,9 @@ def convert_to_markdown_v2(output_data: dict,
                                issue_str = f"**{issue_header}**\n\n{issue_content}\n\n"
                        markdown_text += f"{issue_str}\n\n"
                    except Exception as e:
                        get_logger().exception(
                            f"Failed to process 'Recommended focus areas for review': {e}"
                        )
            if gfm_supported:
                markdown_text += f"</td></tr>\n"
            else:
@ -273,7 +305,9 @@ def convert_to_markdown_v2(output_data: dict,
    return markdown_text


def extract_relevant_lines_str(
    end_line, files, relevant_file, start_line, dedent=False
) -> str:
    """
    Finds 'relevant_file' in 'files', and extracts the lines from 'start_line' to 'end_line' string from the file content.
    """
@ -286,10 +320,16 @@ def extract_relevant_lines_str(end_line, files, relevant_file, start_line, deden
                if not file.head_file:
                    # as a fallback, extract relevant lines directly from patch
                    patch = file.patch
                    get_logger().info(
                        f"No content found in file: '{file.filename}' for 'extract_relevant_lines_str'. Using patch instead"
                    )
                    _, selected_lines = extract_hunk_lines_from_patch(
                        patch, file.filename, start_line, end_line, side='right'
                    )
                    if not selected_lines:
                        get_logger().error(
                            f"Failed to extract relevant lines from patch: {file.filename}"
                        )
                        return ""
                    # filter out '-' lines
                    relevant_lines_str = ""
@ -299,12 +339,16 @@ def extract_relevant_lines_str(end_line, files, relevant_file, start_line, deden
                        relevant_lines_str += line[1:] + '\n'
                else:
                    relevant_file_lines = file.head_file.splitlines()
                    relevant_lines_str = "\n".join(
                        relevant_file_lines[start_line - 1 : end_line]
                    )

                if dedent and relevant_lines_str:
                    # Remove the longest leading string of spaces and tabs common to all lines.
                    relevant_lines_str = textwrap.dedent(relevant_lines_str)
                relevant_lines_str = (
                    f"```{file.language}\n{relevant_lines_str}\n```"
                )
                break

    return relevant_lines_str
@ -325,14 +369,21 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
                ticket_url = ticket_analysis.get('ticket_url', '').strip()
                explanation = ''
                ticket_compliance_level = ''  # Individual ticket compliance
                fully_compliant_str = ticket_analysis.get(
                    'fully_compliant_requirements', ''
                ).strip()
                not_compliant_str = ticket_analysis.get(
                    'not_compliant_requirements', ''
                ).strip()
                requires_further_human_verification = ticket_analysis.get(
                    'requires_further_human_verification', ''
                ).strip()

                if not fully_compliant_str and not not_compliant_str:
                    get_logger().debug(
                        f"Ticket compliance has no requirements",
                        artifact={'ticket_url': ticket_url},
                    )
                    continue

                # Calculate individual ticket compliance level
@ -353,19 +404,27 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:

                # build compliance string
                if fully_compliant_str:
                    explanation += (
                        f"Compliant requirements:\n\n{fully_compliant_str}\n\n"
                    )
                if not_compliant_str:
                    explanation += (
                        f"Non-compliant requirements:\n\n{not_compliant_str}\n\n"
                    )
                if requires_further_human_verification:
                    explanation += f"Requires further human verification:\n\n{requires_further_human_verification}\n\n"
                ticket_compliance_str += f"\n\n**[{ticket_url.split('/')[-1]}]({ticket_url}) - {ticket_compliance_level}**\n\n{explanation}\n\n"

                # for debugging
                if requires_further_human_verification:
                    get_logger().debug(
                        f"Ticket compliance requires further human verification",
                        artifact={
                            'ticket_url': ticket_url,
                            'requires_further_human_verification': requires_further_human_verification,
                            'compliance_level': ticket_compliance_level,
                        },
                    )

            except Exception as e:
                get_logger().exception(f"Failed to process ticket compliance: {e}")
@ -381,7 +440,10 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
            compliance_emoji = '✅'
        elif any(level == 'Not compliant' for level in all_compliance_levels):
            # If there's a mix of compliant and non-compliant tickets
            if any(
                level in ['Fully compliant', 'PR Code Verified']
                for level in all_compliance_levels
            ):
                compliance_level = 'Partially compliant'
                compliance_emoji = '🔶'
            else:
@ -395,7 +457,9 @@ def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
            compliance_emoji = '✅'

        # Set extra statistics outside the ticket loop
        get_settings().set(
            'config.extra_statistics', {'compliance_level': compliance_level}
        )

    # editing table row for ticket compliance analysis
    if gfm_supported:
@ -425,7 +489,9 @@ def process_can_be_split(emoji, value):
            for i, split in enumerate(value):
                title = split.get('title', '')
                relevant_files = split.get('relevant_files', [])
                markdown_text += (
                    f"<details><summary>\n子 PR 主题: <b>{title}</b></summary>\n\n"
                )
                markdown_text += f"___\n\n相关文件:\n\n"
                for file in relevant_files:
                    markdown_text += f"- {file}\n"
@ -464,7 +530,9 @@ def process_can_be_split(emoji, value):
    return markdown_text


def parse_code_suggestion(
    code_suggestion: dict, i: int = 0, gfm_supported: bool = True
) -> str:
    """
    Convert a dictionary of data into markdown format.

@ -484,15 +552,19 @@ def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool
                    markdown_text += f"<tr><td>相关文件</td><td>{relevant_file}</td></tr>"
                    # continue
                elif sub_key.lower() == 'suggestion':
                    markdown_text += (
                        f"<tr><td>{sub_key} </td>"
                        f"<td>\n\n<strong>\n\n{sub_value.strip()}\n\n</strong>\n</td></tr>"
                    )
                elif sub_key.lower() == 'relevant_line':
                    markdown_text += f"<tr><td>相关行</td>"
                    sub_value_list = sub_value.split('](')
                    relevant_line = sub_value_list[0].lstrip('`').lstrip('[')
                    if len(sub_value_list) > 1:
                        link = sub_value_list[1].rstrip(')').strip('`')
                        markdown_text += (
                            f"<td><a href='{link}'>{relevant_line}</a></td>"
                        )
                    else:
                        markdown_text += f"<td>{relevant_line}</td>"
                    markdown_text += "</tr>"
@ -505,11 +577,14 @@ def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool
        for sub_key, sub_value in code_suggestion.items():
            if isinstance(sub_key, str):
                sub_key = sub_key.rstrip()
            if isinstance(sub_value, str):
                sub_value = sub_value.rstrip()
            if isinstance(sub_value, dict):  # "code example"
                markdown_text += f"  - **{sub_key}:**\n"
                for (
                    code_key,
                    code_value,
                ) in sub_value.items():  # 'before' and 'after' code
                    code_str = f"```\n{code_value}\n```"
                    code_str_indented = textwrap.indent(code_str, ' ')
                    markdown_text += f"  - **{code_key}:**\n{code_str_indented}\n"
@ -520,7 +595,9 @@ def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool
                    markdown_text += f" **{sub_key}:** {sub_value} \n"
                if "relevant_line" not in sub_key.lower():  # nicer presentation
                    # markdown_text = markdown_text.rstrip('\n') + "\\\n"  # works for gitlab
                    markdown_text = (
                        markdown_text.rstrip('\n') + " \n"
                    )  # works for gitlab and bitbucker

    markdown_text += "\n"
    return markdown_text
@ -561,9 +638,15 @@ def try_fix_json(review, max_iter=10, code_suggestions=False):
    else:
        closing_bracket = "]}}"

    if (
        review.rfind("'Code feedback': [") > 0 or review.rfind('"Code feedback": [') > 0
    ) or (
        review.rfind("'Code suggestions': [") > 0
        or review.rfind('"Code suggestions": [') > 0
    ):
        last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][
            -1
        ] - 1
        valid_json = False
        iter_count = 0

@ -574,7 +657,9 @@ def try_fix_json(review, max_iter=10, code_suggestions=False):
                review = review[:last_code_suggestion_ind].strip() + closing_bracket
            except json.decoder.JSONDecodeError:
                review = review[:last_code_suggestion_ind]
                last_code_suggestion_ind = [
                    m.end() for m in re.finditer(r"\}\s*,", review)
                ][-1] - 1
                iter_count += 1

        if not valid_json:
@ -629,7 +714,12 @@ def convert_str_to_datetime(date_str):
    return datetime.strptime(date_str, datetime_format)


def load_large_diff(
    filename,
    new_file_content_str: str,
    original_file_content_str: str,
    show_warning: bool = True,
) -> str:
    """
    Generate a patch for a modified file by comparing the original content of the file with the new content provided as
    input.
@ -640,10 +730,14 @@ def load_large_diff(filename, new_file_content_str: str, original_file_content_s
    try:
        original_file_content_str = (original_file_content_str or "").rstrip() + "\n"
        new_file_content_str = (new_file_content_str or "").rstrip() + "\n"
        diff = difflib.unified_diff(
            original_file_content_str.splitlines(keepends=True),
            new_file_content_str.splitlines(keepends=True),
        )
        if get_settings().config.verbosity_level >= 2 and show_warning:
            get_logger().info(
                f"File was modified, but no patch was found. Manually creating patch: {filename}."
            )
        patch = ''.join(diff)
        return patch
    except Exception as e:
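Aside (illustration only, not part of the commit): `load_large_diff` above falls back to building a patch with `difflib.unified_diff`. A small self-contained example of what that call produces, using made-up file contents:

import difflib

old = "a\nb\nc\n"
new = "a\nB\nc\n"

# unified_diff yields patch lines lazily; keepends=True preserves newlines,
# so ''.join(...) reproduces a ready-to-use unified patch body.
diff_lines = difflib.unified_diff(
    old.splitlines(keepends=True),
    new.splitlines(keepends=True),
)
patch = ''.join(diff_lines)
print(patch)
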
@ -693,42 +787,68 @@ def _fix_key_value(key: str, value: str):
    try:
        value = yaml.safe_load(value)
    except Exception as e:
        get_logger().debug(
            f"Failed to parse YAML for config override {key}={value}", exc_info=e
        )
    return key, value


def load_yaml(
    response_text: str, keys_fix_yaml: List[str] = [], first_key="", last_key=""
) -> dict:
    response_text = (
        response_text.strip('\n').removeprefix('```yaml').rstrip().removesuffix('```')
    )
    try:
        data = yaml.safe_load(response_text)
    except Exception as e:
        get_logger().warning(f"Initial failure to parse AI prediction: {e}")
        data = try_fix_yaml(
            response_text,
            keys_fix_yaml=keys_fix_yaml,
            first_key=first_key,
            last_key=last_key,
        )
        if not data:
            get_logger().error(
                f"Failed to parse AI prediction after fallbacks",
                artifact={'response_text': response_text},
            )
        else:
            get_logger().info(
                f"Successfully parsed AI prediction after fallbacks",
                artifact={'response_text': response_text},
            )
    return data


def try_fix_yaml(
    response_text: str,
    keys_fix_yaml: List[str] = [],
    first_key="",
    last_key="",
) -> dict:
    response_text_lines = response_text.split('\n')

    keys_yaml = [
        'relevant line:',
        'suggestion content:',
        'relevant file:',
        'existing code:',
        'improved code:',
    ]
    keys_yaml = keys_yaml + keys_fix_yaml
    # first fallback - try to convert 'relevant line: ...' to relevant line: |-\n ...'
    response_text_lines_copy = response_text_lines.copy()
    for i in range(0, len(response_text_lines_copy)):
        for key in keys_yaml:
            if (
                key in response_text_lines_copy[i]
                and not '|' in response_text_lines_copy[i]
            ):
                response_text_lines_copy[i] = response_text_lines_copy[i].replace(
                    f'{key}', f'{key} |\n '
                )
    try:
        data = yaml.safe_load('\n'.join(response_text_lines_copy))
        get_logger().info(f"Successfully parsed AI prediction after adding |-\n")
@ -743,22 +863,26 @@ def try_fix_yaml(response_text: str,
        snippet_text = snippet.group()
        try:
            data = yaml.safe_load(snippet_text.removeprefix('```yaml').rstrip('`'))
            get_logger().info(
                f"Successfully parsed AI prediction after extracting yaml snippet"
            )
            return data
        except:
            pass

    # third fallback - try to remove leading and trailing curly brackets
    response_text_copy = (
        response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n')
    )
    try:
        data = yaml.safe_load(response_text_copy)
        get_logger().info(
            f"Successfully parsed AI prediction after removing curly brackets"
        )
        return data
    except:
        pass

    # forth fallback - try to extract yaml snippet by 'first_key' and 'last_key'
    # note that 'last_key' can be in practice a key that is not the last key in the yaml snippet.
    # it just needs to be some inner key, so we can look for newlines after it
@ -767,13 +891,23 @@ def try_fix_yaml(response_text: str,
    if index_start == -1:
        index_start = response_text.find(f"{first_key}:")
    index_last_code = response_text.rfind(f"{last_key}:")
    index_end = response_text.find(
        "\n\n", index_last_code
    )  # look for newlines after last_key
    if index_end == -1:
        index_end = len(response_text)
    response_text_copy = (
        response_text[index_start:index_end]
        .strip()
        .strip('```yaml')
        .strip('`')
        .strip()
    )
    try:
        data = yaml.safe_load(response_text_copy)
        get_logger().info(
            f"Successfully parsed AI prediction after extracting yaml snippet"
        )
        return data
    except:
        pass
@ -784,7 +918,9 @@ def try_fix_yaml(response_text: str,
            response_text_lines_copy[i] = ' ' + response_text_lines_copy[i][1:]
    try:
        data = yaml.safe_load('\n'.join(response_text_lines_copy))
        get_logger().info(
            f"Successfully parsed AI prediction after removing leading '+'"
        )
        return data
    except:
        pass
@ -794,7 +930,9 @@ def try_fix_yaml(response_text: str,
        response_text_lines_tmp = '\n'.join(response_text_lines[:-i])
        try:
            data = yaml.safe_load(response_text_lines_tmp)
            get_logger().info(
                f"Successfully parsed AI prediction after removing {i} lines"
            )
            return data
        except:
            pass
@ -820,11 +958,14 @@ def set_custom_labels(variables, git_provider=None):
    for k, v in labels.items():
        description = "'" + v['description'].strip('\n').replace('\n', '\\n') + "'"
        # variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}"
        variables[
            "custom_labels_class"
        ] += f"\n {k.lower().replace(' ', '_')} = {description}"
        labels_minimal_to_labels_dict[k.lower().replace(' ', '_')] = k
        counter += 1
    variables["labels_minimal_to_labels_dict"] = labels_minimal_to_labels_dict


def get_user_labels(current_labels: List[str] = None):
    """
    Only keep labels that has been added by the user
@ -866,14 +1007,22 @@ def get_max_tokens(model):
    elif settings.config.custom_model_max_tokens > 0:
        max_tokens_model = settings.config.custom_model_max_tokens
    else:
        raise Exception(
            f"Ensure {model} is defined in MAX_TOKENS in ./pr_agent/algo/__init__.py or set a positive value for it in config.custom_model_max_tokens"
        )

    if settings.config.max_model_tokens and settings.config.max_model_tokens > 0:
        max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
    return max_tokens_model


def clip_tokens(
    text: str,
    max_tokens: int,
    add_three_dots=True,
    num_input_tokens=None,
    delete_last_line=False,
) -> str:
    """
    Clip the number of tokens in a string to a maximum number of tokens.

@ -909,14 +1058,15 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_token
                clipped_text = clipped_text.rsplit('\n', 1)[0]
            if add_three_dots:
                clipped_text += "\n...(truncated)"
        else:  # if the text is empty
            clipped_text = ""

        return clipped_text
    except Exception as e:
        get_logger().warning(f"Failed to clip tokens: {e}")
        return text


def replace_code_tags(text):
    """
    Replace odd instances of ` with <code> and even instances of ` with </code>
@ -928,15 +1078,16 @@ def replace_code_tags(text):
    return ''.join(parts)


def find_line_number_of_relevant_line_in_file(
    diff_files: List[FilePatchInfo],
    relevant_file: str,
    relevant_line_in_file: str,
    absolute_position: int = None,
) -> Tuple[int, int]:
    position = -1
    if absolute_position is None:
        absolute_position = -1
    re_hunk_header = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

    if not diff_files:
        return position, absolute_position
@ -947,7 +1098,7 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
            patch_lines = patch.splitlines()
            delta = 0
            start1, size1, start2, size2 = 0, 0, 0, 0
            if absolute_position != -1:  # matching absolute to relative
                for i, line in enumerate(patch_lines):
                    # new hunk
                    if line.startswith('@@'):
@ -965,12 +1116,12 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
                            break
            else:
                # try to find the line in the patch using difflib, with some margin of error
                matches_difflib: list[str | Any] = difflib.get_close_matches(
                    relevant_line_in_file, patch_lines, n=3, cutoff=0.93
                )
                if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'):
                    relevant_line_in_file = matches_difflib[0]

                for i, line in enumerate(patch_lines):
                    if line.startswith('@@'):
                        delta = 0
@ -1002,19 +1153,26 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
                        break
    return position, absolute_position


def get_rate_limit_status(github_token) -> dict:
    GITHUB_API_URL = (
        get_settings(use_context=False)
        .get("GITHUB.BASE_URL", "https://api.github.com")
        .rstrip("/")
    )  # "https://api.github.com"
    # GITHUB_API_URL = "https://api.github.com"
    RATE_LIMIT_URL = f"{GITHUB_API_URL}/rate_limit"
    HEADERS = {
        "Accept": "application/vnd.github.v3+json",
        "Authorization": f"token {github_token}",
    }

    response = requests.get(RATE_LIMIT_URL, headers=HEADERS)
    try:
        rate_limit_info = response.json()
        if (
            rate_limit_info.get('message') == 'Rate limiting is not enabled.'
        ):  # for github enterprise
            return {'resources': {}}
        response.raise_for_status()  # Check for HTTP errors
    except:  # retry
@ -1024,12 +1182,16 @@ def get_rate_limit_status(github_token) -> dict:
    return rate_limit_info


def validate_rate_limit_github(
    github_token, installation_id=None, threshold=0.1
) -> bool:
    try:
        rate_limit_status = get_rate_limit_status(github_token)
        if installation_id:
            get_logger().debug(
                f"installation_id: {installation_id}, Rate limit status: {rate_limit_status['rate']}"
            )
        # validate that the rate limit is not exceeded
        for key, value in rate_limit_status['resources'].items():
            if value['remaining'] < value['limit'] * threshold:
@ -1037,8 +1199,9 @@ def validate_rate_limit_github(github_token, installation_id=None, threshold=0.1
                return False
        return True
    except Exception as e:
        get_logger().error(
            f"Error in rate limit {e}", artifact={"traceback": traceback.format_exc()}
        )
        return True

@ -1051,7 +1214,9 @@ def validate_and_await_rate_limit(github_token):
            get_logger().error(f"key: {key}, value: {value}")
            sleep_time_sec = value['reset'] - datetime.now().timestamp()
            sleep_time_hour = sleep_time_sec / 3600.0
            get_logger().error(
                f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours"
            )
            if sleep_time_sec > 0:
                time.sleep(sleep_time_sec + 1)
            rate_limit_status = get_rate_limit_status(github_token)
@ -1068,22 +1233,39 @@ def github_action_output(output_data: dict, key_name: str):

        key_data = output_data.get(key_name, {})
        with open(os.environ['GITHUB_OUTPUT'], 'a') as fh:
            print(
                f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}",
                file=fh,
            )
    except Exception as e:
        get_logger().error(f"Failed to write to GitHub Action output: {e}")
        return


def show_relevant_configurations(relevant_section: str) -> str:
    skip_keys = [
        'ai_disclaimer',
        'ai_disclaimer_title',
        'ANALYTICS_FOLDER',
        'secret_provider',
        "skip_keys",
        "app_id",
        "redirect",
        'trial_prefix_message',
        'no_eligible_message',
        'identity_provider',
        'ALLOWED_REPOS',
        'APP_NAME',
    ]
    extra_skip_keys = get_settings().config.get('config.skip_keys', [])
    if extra_skip_keys:
        skip_keys.extend(extra_skip_keys)

    markdown_text = ""
    markdown_text += (
        "\n<hr>\n<details> <summary><strong>🛠️ 相关配置:</strong></summary> \n\n"
    )
    markdown_text += "<br>以下是相关工具地配置 [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml):\n\n"
    markdown_text += f"**[config**]\n```yaml\n\n"
    for key, value in get_settings().config.items():
        if key in skip_keys:
@ -1099,6 +1281,7 @@ def show_relevant_configurations(relevant_section: str) -> str:
    markdown_text += "\n</details>\n"
    return markdown_text


def is_value_no(value):
    if not value:
        return True
@ -1122,7 +1305,7 @@ def string_to_uniform_number(s: str) -> float:
    # Convert the hash to an integer
    hash_int = int(hash_object.hexdigest(), 16)
    # Normalize the integer to the range [0, 1]
    max_hash_int = 2**256 - 1
    uniform_number = float(hash_int) / max_hash_int
    return uniform_number

@ -1131,7 +1314,9 @@ def process_description(description_full: str) -> Tuple[str, List]:
    if not description_full:
        return "", []

    description_split = description_full.split(
        PRDescriptionHeader.CHANGES_WALKTHROUGH.value
    )
    base_description_str = description_split[0]
    changes_walkthrough_str = ""
    files = []
@ -1167,45 +1352,58 @@ def process_description(description_full: str) -> Tuple[str, List]:
                pattern_back = r'<details>\s*<summary><strong>(.*?)</strong><dd><code>(.*?)</code>.*?</summary>\s*<hr>\s*(.*?)\n\n\s*(.*?)</details>'
                res = re.search(pattern_back, file_data, re.DOTALL)
                if not res or res.lastindex != 4:
                    pattern_back = r'<details>\s*<summary><strong>(.*?)</strong>\s*<dd><code>(.*?)</code>.*?</summary>\s*<hr>\s*(.*?)\s*-\s*(.*?)\s*</details>'  # looking for hypen ('- ')
                    res = re.search(pattern_back, file_data, re.DOTALL)
                if res and res.lastindex == 4:
                    short_filename = res.group(1).strip()
                    short_summary = res.group(2).strip()
                    long_filename = res.group(3).strip()
                    long_summary = res.group(4).strip()
                    long_summary = (
                        long_summary.replace('<br> *', '\n*')
                        .replace('<br>', '')
                        .replace('\n', '<br>')
                    )
                    long_summary = h.handle(long_summary).strip()
                    if long_summary.startswith('\\-'):
                        long_summary = "* " + long_summary[2:]
                    elif not long_summary.startswith('*'):
                        long_summary = f"* {long_summary}"

                    files.append(
                        {
                            'short_file_name': short_filename,
                            'full_file_name': long_filename,
                            'short_summary': short_summary,
                            'long_summary': long_summary,
                        }
                    )
                else:
                    if '<code>...</code>' in file_data:
                        pass  # PR with many files. some did not get analyzed
                    else:
                        get_logger().error(
                            f"Failed to parse description",
                            artifact={'description': file_data},
                        )
            except Exception as e:
                get_logger().exception(
                    f"Failed to process description: {e}",
                    artifact={'description': file_data},
                )

    except Exception as e:
        get_logger().exception(f"Failed to process description: {e}")

    return base_description_str, files


def get_version() -> str:
    # First check pyproject.toml if running directly out of repository
    if os.path.exists("pyproject.toml"):
        if sys.version_info >= (3, 11):
            import tomllib

            with open("pyproject.toml", "rb") as f:
                data = tomllib.load(f)
                if "project" in data and "version" in data["project"]:
@ -1213,7 +1411,9 @@ def get_version() -> str:
                else:
                    get_logger().warning("Version not found in pyproject.toml")
        else:
            get_logger().warning(
                "Unable to determine local version from pyproject.toml"
            )

    # Otherwise get the installed pip package version
    try:
@ -12,8 +12,9 @@ setup_logger(log_level)


def set_parser():
    parser = argparse.ArgumentParser(
        description='AI based pull request analyzer',
        usage="""\
Usage: cli.py --pr-url=<URL on supported git hosting service> <command> [<args>].
For example:
- cli.py --pr_url=... review
@ -45,11 +46,20 @@ def set_parser():
Configuration:
To edit any configuration parameter from 'configuration.toml', just add -config_path=<value>.
For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
""",
    )
    parser.add_argument(
        '--version', action='version', version=f'pr-agent {get_version()}'
    )
    parser.add_argument(
        '--pr_url', type=str, help='The URL of the PR to review', default=None
    )
    parser.add_argument(
        '--issue_url', type=str, help='The URL of the Issue to review', default=None
    )
    parser.add_argument(
        'command', type=str, help='The', choices=commands, default='review'
    )
    parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
    return parser

@ -76,14 +86,24 @@ def run(inargs=None, args=None):

    async def inner():
        if args.issue_url:
            result = await asyncio.create_task(
                PRAgent().handle_request(args.issue_url, [command] + args.rest)
            )
        else:
            result = await asyncio.create_task(
                PRAgent().handle_request(args.pr_url, [command] + args.rest)
            )

        if get_settings().litellm.get("enable_callbacks", False):
            # There may be additional events on the event queue from the run above. If there are give them time to complete.
            get_logger().debug("Waiting for event queue to complete")
            await asyncio.wait(
                [
                    task
                    for task in asyncio.all_tasks()
                    if task is not asyncio.current_task()
                ]
            )

        return result

@ -7,7 +7,9 @@ def main():
|
|||||||
provider = "github" # GitHub provider
|
provider = "github" # GitHub provider
|
||||||
user_token = "..." # GitHub user token
|
user_token = "..." # GitHub user token
|
||||||
openai_key = "..." # OpenAI key
|
openai_key = "..." # OpenAI key
|
||||||
pr_url = "..." # PR URL, for example 'https://github.com/Codium-ai/pr-agent/pull/809'
|
pr_url = (
|
||||||
|
"..." # PR URL, for example 'https://github.com/Codium-ai/pr-agent/pull/809'
|
||||||
|
)
|
||||||
command = "/review" # Command to run (e.g. '/review', '/describe', '/ask="What is the purpose of this PR?"')
|
command = "/review" # Command to run (e.g. '/review', '/describe', '/ask="What is the purpose of this PR?"')
|
||||||
|
|
||||||
# Setting the configurations
|
# Setting the configurations
|
||||||
|
|||||||
@ -11,26 +11,29 @@ current_dir = dirname(abspath(__file__))
|
|||||||
global_settings = Dynaconf(
|
global_settings = Dynaconf(
|
||||||
envvar_prefix=False,
|
envvar_prefix=False,
|
||||||
merge_enabled=True,
|
merge_enabled=True,
|
||||||
settings_files=[join(current_dir, f) for f in [
|
settings_files=[
|
||||||
"settings/configuration.toml",
|
join(current_dir, f)
|
||||||
"settings/ignore.toml",
|
for f in [
|
||||||
"settings/language_extensions.toml",
|
"settings/configuration.toml",
|
||||||
"settings/pr_reviewer_prompts.toml",
|
"settings/ignore.toml",
|
||||||
"settings/pr_questions_prompts.toml",
|
"settings/language_extensions.toml",
|
||||||
"settings/pr_line_questions_prompts.toml",
|
"settings/pr_reviewer_prompts.toml",
|
||||||
"settings/pr_description_prompts.toml",
|
"settings/pr_questions_prompts.toml",
|
||||||
"settings/pr_code_suggestions_prompts.toml",
|
"settings/pr_line_questions_prompts.toml",
|
||||||
"settings/pr_code_suggestions_reflect_prompts.toml",
|
"settings/pr_description_prompts.toml",
|
||||||
"settings/pr_sort_code_suggestions_prompts.toml",
|
"settings/pr_code_suggestions_prompts.toml",
|
||||||
"settings/pr_information_from_user_prompts.toml",
|
"settings/pr_code_suggestions_reflect_prompts.toml",
|
||||||
"settings/pr_update_changelog_prompts.toml",
|
"settings/pr_sort_code_suggestions_prompts.toml",
|
||||||
"settings/pr_custom_labels.toml",
|
"settings/pr_information_from_user_prompts.toml",
|
||||||
"settings/pr_add_docs.toml",
|
"settings/pr_update_changelog_prompts.toml",
|
||||||
"settings/custom_labels.toml",
|
"settings/pr_custom_labels.toml",
|
||||||
"settings/pr_help_prompts.toml",
|
"settings/pr_add_docs.toml",
|
||||||
"settings/.secrets.toml",
|
"settings/custom_labels.toml",
|
||||||
"settings_prod/.secrets.toml",
|
"settings/pr_help_prompts.toml",
|
||||||
]]
|
"settings/.secrets.toml",
|
||||||
|
"settings_prod/.secrets.toml",
|
||||||
|
]
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -3,8 +3,9 @@ from starlette_context import context
|
|||||||
from utils.pr_agent.config_loader import get_settings
|
from utils.pr_agent.config_loader import get_settings
|
||||||
from utils.pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider
|
from utils.pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider
|
||||||
from utils.pr_agent.git_providers.bitbucket_provider import BitbucketProvider
|
from utils.pr_agent.git_providers.bitbucket_provider import BitbucketProvider
|
||||||
from utils.pr_agent.git_providers.bitbucket_server_provider import \
|
from utils.pr_agent.git_providers.bitbucket_server_provider import (
|
||||||
BitbucketServerProvider
|
BitbucketServerProvider,
|
||||||
|
)
|
||||||
from utils.pr_agent.git_providers.codecommit_provider import CodeCommitProvider
|
from utils.pr_agent.git_providers.codecommit_provider import CodeCommitProvider
|
||||||
from utils.pr_agent.git_providers.gerrit_provider import GerritProvider
|
from utils.pr_agent.git_providers.gerrit_provider import GerritProvider
|
||||||
from utils.pr_agent.git_providers.git_provider import GitProvider
|
from utils.pr_agent.git_providers.git_provider import GitProvider
|
||||||
@ -28,7 +29,9 @@ def get_git_provider():
|
|||||||
try:
|
try:
|
||||||
provider_id = get_settings().config.git_provider
|
provider_id = get_settings().config.git_provider
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
raise ValueError("git_provider is a required attribute in the configuration file") from e
|
raise ValueError(
|
||||||
|
"git_provider is a required attribute in the configuration file"
|
||||||
|
) from e
|
||||||
if provider_id not in _GIT_PROVIDERS:
|
if provider_id not in _GIT_PROVIDERS:
|
||||||
raise ValueError(f"Unknown git provider: {provider_id}")
|
raise ValueError(f"Unknown git provider: {provider_id}")
|
||||||
return _GIT_PROVIDERS[provider_id]
|
return _GIT_PROVIDERS[provider_id]
|
||||||
|
|||||||
@ -6,25 +6,33 @@ from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
|||||||
|
|
||||||
from ..algo.file_filter import filter_ignored
|
from ..algo.file_filter import filter_ignored
|
||||||
from ..algo.language_handler import is_valid_file
|
from ..algo.language_handler import is_valid_file
|
||||||
from ..algo.utils import (PRDescriptionHeader, find_line_number_of_relevant_line_in_file,
|
from ..algo.utils import (
|
||||||
load_large_diff)
|
PRDescriptionHeader,
|
||||||
|
find_line_number_of_relevant_line_in_file,
|
||||||
|
load_large_diff,
|
||||||
|
)
|
||||||
from ..config_loader import get_settings
|
from ..config_loader import get_settings
|
||||||
from ..log import get_logger
|
from ..log import get_logger
|
||||||
from .git_provider import GitProvider
|
from .git_provider import GitProvider
|
||||||
|
|
||||||
AZURE_DEVOPS_AVAILABLE = True
|
AZURE_DEVOPS_AVAILABLE = True
|
||||||
ADO_APP_CLIENT_DEFAULT_ID = "499b84ac-1321-427f-aa17-267ca6975798/.default"
|
ADO_APP_CLIENT_DEFAULT_ID = "499b84ac-1321-427f-aa17-267ca6975798/.default"
|
||||||
MAX_PR_DESCRIPTION_AZURE_LENGTH = 4000-1
|
MAX_PR_DESCRIPTION_AZURE_LENGTH = 4000 - 1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
from azure.devops.connection import Connection
|
from azure.devops.connection import Connection
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
from azure.devops.v7_1.git.models import (Comment, CommentThread,
|
from azure.devops.v7_1.git.models import (
|
||||||
GitPullRequest,
|
Comment,
|
||||||
GitPullRequestIterationChanges,
|
CommentThread,
|
||||||
GitVersionDescriptor)
|
GitPullRequest,
|
||||||
|
GitPullRequestIterationChanges,
|
||||||
|
GitVersionDescriptor,
|
||||||
|
)
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
from azure.identity import DefaultAzureCredential
|
from azure.identity import DefaultAzureCredential
|
||||||
from msrest.authentication import BasicAuthentication
|
from msrest.authentication import BasicAuthentication
|
||||||
@ -33,9 +41,8 @@ except ImportError:
|
|||||||
|
|
||||||
|
|
||||||
class AzureDevopsProvider(GitProvider):
|
class AzureDevopsProvider(GitProvider):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
|
self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
|
||||||
):
|
):
|
||||||
if not AZURE_DEVOPS_AVAILABLE:
|
if not AZURE_DEVOPS_AVAILABLE:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@ -67,13 +74,16 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
|
|
||||||
if not relevant_lines_start or relevant_lines_start == -1:
|
if not relevant_lines_start or relevant_lines_start == -1:
|
||||||
get_logger().warning(
|
get_logger().warning(
|
||||||
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
|
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if relevant_lines_end < relevant_lines_start:
|
if relevant_lines_end < relevant_lines_start:
|
||||||
get_logger().warning(f"Failed to publish code suggestion, "
|
get_logger().warning(
|
||||||
f"relevant_lines_end is {relevant_lines_end} and "
|
f"Failed to publish code suggestion, "
|
||||||
f"relevant_lines_start is {relevant_lines_start}")
|
f"relevant_lines_end is {relevant_lines_end} and "
|
||||||
|
f"relevant_lines_start is {relevant_lines_start}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if relevant_lines_end > relevant_lines_start:
|
if relevant_lines_end > relevant_lines_start:
|
||||||
@ -98,30 +108,32 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
for post_parameters in post_parameters_list:
|
for post_parameters in post_parameters_list:
|
||||||
try:
|
try:
|
||||||
comment = Comment(content=post_parameters["body"], comment_type=1)
|
comment = Comment(content=post_parameters["body"], comment_type=1)
|
||||||
thread = CommentThread(comments=[comment],
|
thread = CommentThread(
|
||||||
thread_context={
|
comments=[comment],
|
||||||
"filePath": post_parameters["path"],
|
thread_context={
|
||||||
"rightFileStart": {
|
"filePath": post_parameters["path"],
|
||||||
"line": post_parameters["start_line"],
|
"rightFileStart": {
|
||||||
"offset": 1,
|
"line": post_parameters["start_line"],
|
||||||
},
|
"offset": 1,
|
||||||
"rightFileEnd": {
|
},
|
||||||
"line": post_parameters["line"],
|
"rightFileEnd": {
|
||||||
"offset": 1,
|
"line": post_parameters["line"],
|
||||||
},
|
"offset": 1,
|
||||||
})
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
self.azure_devops_client.create_thread(
|
self.azure_devops_client.create_thread(
|
||||||
comment_thread=thread,
|
comment_thread=thread,
|
||||||
project=self.workspace_slug,
|
project=self.workspace_slug,
|
||||||
repository_id=self.repo_slug,
|
repository_id=self.repo_slug,
|
||||||
pull_request_id=self.pr_num
|
pull_request_id=self.pr_num,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().warning(f"Azure failed to publish code suggestion, error: {e}")
|
get_logger().warning(
|
||||||
|
f"Azure failed to publish code suggestion, error: {e}"
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_pr_description_full(self) -> str:
|
def get_pr_description_full(self) -> str:
|
||||||
return self.pr.description
|
return self.pr.description
|
||||||
|
|
||||||
@ -204,9 +216,9 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
def get_files(self):
|
def get_files(self):
|
||||||
files = []
|
files = []
|
||||||
for i in self.azure_devops_client.get_pull_request_commits(
|
for i in self.azure_devops_client.get_pull_request_commits(
|
||||||
project=self.workspace_slug,
|
project=self.workspace_slug,
|
||||||
repository_id=self.repo_slug,
|
repository_id=self.repo_slug,
|
||||||
pull_request_id=self.pr_num,
|
pull_request_id=self.pr_num,
|
||||||
):
|
):
|
||||||
changes_obj = self.azure_devops_client.get_changes(
|
changes_obj = self.azure_devops_client.get_changes(
|
||||||
project=self.workspace_slug,
|
project=self.workspace_slug,
|
||||||
@ -220,7 +232,6 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
|
|
||||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
if self.diff_files:
|
if self.diff_files:
|
||||||
return self.diff_files
|
return self.diff_files
|
||||||
|
|
||||||
@ -231,18 +242,20 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
iterations = self.azure_devops_client.get_pull_request_iterations(
|
iterations = self.azure_devops_client.get_pull_request_iterations(
|
||||||
repository_id=self.repo_slug,
|
repository_id=self.repo_slug,
|
||||||
pull_request_id=self.pr_num,
|
pull_request_id=self.pr_num,
|
||||||
project=self.workspace_slug
|
project=self.workspace_slug,
|
||||||
)
|
)
|
||||||
changes = None
|
changes = None
|
||||||
if iterations:
|
if iterations:
|
||||||
iteration_id = iterations[-1].id # Get the last iteration (most recent changes)
|
iteration_id = iterations[
|
||||||
|
-1
|
||||||
|
].id # Get the last iteration (most recent changes)
|
||||||
|
|
||||||
# Get changes for the iteration
|
# Get changes for the iteration
|
||||||
changes = self.azure_devops_client.get_pull_request_iteration_changes(
|
changes = self.azure_devops_client.get_pull_request_iteration_changes(
|
||||||
repository_id=self.repo_slug,
|
repository_id=self.repo_slug,
|
||||||
pull_request_id=self.pr_num,
|
pull_request_id=self.pr_num,
|
||||||
iteration_id=iteration_id,
|
iteration_id=iteration_id,
|
||||||
project=self.workspace_slug
|
project=self.workspace_slug,
|
||||||
)
|
)
|
||||||
diff_files = []
|
diff_files = []
|
||||||
diffs = []
|
diffs = []
|
||||||
@ -253,7 +266,9 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
path = item.get('path', None)
|
path = item.get('path', None)
|
||||||
if path:
|
if path:
|
||||||
diffs.append(path)
|
diffs.append(path)
|
||||||
diff_types[path] = change.additional_properties.get('changeType', 'Unknown')
|
diff_types[path] = change.additional_properties.get(
|
||||||
|
'changeType', 'Unknown'
|
||||||
|
)
|
||||||
|
|
||||||
# wrong implementation - gets all the files that were changed in any commit in the PR
|
# wrong implementation - gets all the files that were changed in any commit in the PR
|
||||||
# commits = self.azure_devops_client.get_pull_request_commits(
|
# commits = self.azure_devops_client.get_pull_request_commits(
|
||||||
@ -284,9 +299,13 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
diffs = filter_ignored(diffs_original, 'azure')
|
diffs = filter_ignored(diffs_original, 'azure')
|
||||||
if diffs_original != diffs:
|
if diffs_original != diffs:
|
||||||
try:
|
try:
|
||||||
get_logger().info(f"Filtered out [ignore] files for pull request:", extra=
|
get_logger().info(
|
||||||
{"files": diffs_original, # diffs is just a list of names
|
f"Filtered out [ignore] files for pull request:",
|
||||||
"filtered_files": diffs})
|
extra={
|
||||||
|
"files": diffs_original, # diffs is just a list of names
|
||||||
|
"filtered_files": diffs,
|
||||||
|
},
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -311,7 +330,10 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
|
|
||||||
new_file_content_str = new_file_content_str.content
|
new_file_content_str = new_file_content_str.content
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
get_logger().error(f"Failed to retrieve new file content of {file} at version {version}", error=error)
|
get_logger().error(
|
||||||
|
f"Failed to retrieve new file content of {file} at version {version}",
|
||||||
|
error=error,
|
||||||
|
)
|
||||||
# get_logger().error(
|
# get_logger().error(
|
||||||
# "Failed to retrieve new file content of %s at version %s. Error: %s",
|
# "Failed to retrieve new file content of %s at version %s. Error: %s",
|
||||||
# file,
|
# file,
|
||||||
@ -325,7 +347,9 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
edit_type = EDIT_TYPE.ADDED
|
edit_type = EDIT_TYPE.ADDED
|
||||||
elif diff_types[file] == "delete":
|
elif diff_types[file] == "delete":
|
||||||
edit_type = EDIT_TYPE.DELETED
|
edit_type = EDIT_TYPE.DELETED
|
||||||
elif "rename" in diff_types[file]: # diff_type can be `rename` | `edit, rename`
|
elif (
|
||||||
|
"rename" in diff_types[file]
|
||||||
|
): # diff_type can be `rename` | `edit, rename`
|
||||||
edit_type = EDIT_TYPE.RENAMED
|
edit_type = EDIT_TYPE.RENAMED
|
||||||
|
|
||||||
version = GitVersionDescriptor(
|
version = GitVersionDescriptor(
|
||||||
@ -345,17 +369,27 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
)
|
)
|
||||||
original_file_content_str = original_file_content_str.content
|
original_file_content_str = original_file_content_str.content
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
get_logger().error(f"Failed to retrieve original file content of {file} at version {version}", error=error)
|
get_logger().error(
|
||||||
|
f"Failed to retrieve original file content of {file} at version {version}",
|
||||||
|
error=error,
|
||||||
|
)
|
||||||
original_file_content_str = ""
|
original_file_content_str = ""
|
||||||
|
|
||||||
patch = load_large_diff(
|
patch = load_large_diff(
|
||||||
file, new_file_content_str, original_file_content_str, show_warning=False
|
file,
|
||||||
|
new_file_content_str,
|
||||||
|
original_file_content_str,
|
||||||
|
show_warning=False,
|
||||||
).rstrip()
|
).rstrip()
|
||||||
|
|
||||||
# count number of lines added and removed
|
# count number of lines added and removed
|
||||||
patch_lines = patch.splitlines(keepends=True)
|
patch_lines = patch.splitlines(keepends=True)
|
||||||
num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
|
num_plus_lines = len(
|
||||||
num_minus_lines = len([line for line in patch_lines if line.startswith('-')])
|
[line for line in patch_lines if line.startswith('+')]
|
||||||
|
)
|
||||||
|
num_minus_lines = len(
|
||||||
|
[line for line in patch_lines if line.startswith('-')]
|
||||||
|
)
|
||||||
|
|
||||||
diff_files.append(
|
diff_files.append(
|
||||||
FilePatchInfo(
|
FilePatchInfo(
|
||||||
@ -376,27 +410,35 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
get_logger().exception(f"Failed to get diff files, error: {e}")
|
get_logger().exception(f"Failed to get diff files, error: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False, thread_context=None):
|
def publish_comment(
|
||||||
|
self, pr_comment: str, is_temporary: bool = False, thread_context=None
|
||||||
|
):
|
||||||
if is_temporary and not get_settings().config.publish_output_progress:
|
if is_temporary and not get_settings().config.publish_output_progress:
|
||||||
get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
|
get_logger().debug(
|
||||||
|
f"Skipping publish_comment for temporary comment: {pr_comment}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
comment = Comment(content=pr_comment)
|
comment = Comment(content=pr_comment)
|
||||||
thread = CommentThread(comments=[comment], thread_context=thread_context, status=5)
|
thread = CommentThread(
|
||||||
|
comments=[comment], thread_context=thread_context, status=5
|
||||||
|
)
|
||||||
thread_response = self.azure_devops_client.create_thread(
|
thread_response = self.azure_devops_client.create_thread(
|
||||||
comment_thread=thread,
|
comment_thread=thread,
|
||||||
project=self.workspace_slug,
|
project=self.workspace_slug,
|
||||||
repository_id=self.repo_slug,
|
repository_id=self.repo_slug,
|
||||||
pull_request_id=self.pr_num,
|
pull_request_id=self.pr_num,
|
||||||
)
|
)
|
||||||
response = {"thread_id": thread_response.id, "comment_id": thread_response.comments[0].id}
|
response = {
|
||||||
|
"thread_id": thread_response.id,
|
||||||
|
"comment_id": thread_response.comments[0].id,
|
||||||
|
}
|
||||||
if is_temporary:
|
if is_temporary:
|
||||||
self.temp_comments.append(response)
|
self.temp_comments.append(response)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def publish_description(self, pr_title: str, pr_body: str):
|
def publish_description(self, pr_title: str, pr_body: str):
|
||||||
if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:
|
if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:
|
||||||
|
usage_guide_text = '<details> <summary><strong>✨ Describe tool usage guide:</strong></summary><hr>'
|
||||||
usage_guide_text='<details> <summary><strong>✨ Describe tool usage guide:</strong></summary><hr>'
|
|
||||||
ind = pr_body.find(usage_guide_text)
|
ind = pr_body.find(usage_guide_text)
|
||||||
if ind != -1:
|
if ind != -1:
|
||||||
pr_body = pr_body[:ind]
|
pr_body = pr_body[:ind]
|
||||||
@ -409,7 +451,10 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
|
|
||||||
if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:
|
if len(pr_body) > MAX_PR_DESCRIPTION_AZURE_LENGTH:
|
||||||
trunction_message = " ... (description truncated due to length limit)"
|
trunction_message = " ... (description truncated due to length limit)"
|
||||||
pr_body = pr_body[:MAX_PR_DESCRIPTION_AZURE_LENGTH - len(trunction_message)] + trunction_message
|
pr_body = (
|
||||||
|
pr_body[: MAX_PR_DESCRIPTION_AZURE_LENGTH - len(trunction_message)]
|
||||||
|
+ trunction_message
|
||||||
|
)
|
||||||
get_logger().warning("PR description was truncated due to length limit")
|
get_logger().warning("PR description was truncated due to length limit")
|
||||||
try:
|
try:
|
||||||
updated_pr = GitPullRequest()
|
updated_pr = GitPullRequest()
|
||||||
@ -433,50 +478,79 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().exception(f"Failed to remove temp comments, error: {e}")
|
get_logger().exception(f"Failed to remove temp comments, error: {e}")
|
||||||
|
|
||||||
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
|
def publish_inline_comment(
|
||||||
self.publish_inline_comments([self.create_inline_comment(body, relevant_file, relevant_line_in_file)])
|
self,
|
||||||
|
body: str,
|
||||||
|
relevant_file: str,
|
||||||
|
relevant_line_in_file: str,
|
||||||
|
original_suggestion=None,
|
||||||
|
):
|
||||||
|
self.publish_inline_comments(
|
||||||
|
[self.create_inline_comment(body, relevant_file, relevant_line_in_file)]
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_inline_comment(
|
||||||
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
|
self,
|
||||||
absolute_position: int = None):
|
body: str,
|
||||||
position, absolute_position = find_line_number_of_relevant_line_in_file(self.get_diff_files(),
|
relevant_file: str,
|
||||||
relevant_file.strip('`'),
|
relevant_line_in_file: str,
|
||||||
relevant_line_in_file,
|
absolute_position: int = None,
|
||||||
absolute_position)
|
):
|
||||||
|
position, absolute_position = find_line_number_of_relevant_line_in_file(
|
||||||
|
self.get_diff_files(),
|
||||||
|
relevant_file.strip('`'),
|
||||||
|
relevant_line_in_file,
|
||||||
|
absolute_position,
|
||||||
|
)
|
||||||
if position == -1:
|
if position == -1:
|
||||||
if get_settings().config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
|
get_logger().info(
|
||||||
|
f"Could not find position for {relevant_file} {relevant_line_in_file}"
|
||||||
|
)
|
||||||
subject_type = "FILE"
|
subject_type = "FILE"
|
||||||
else:
|
else:
|
||||||
subject_type = "LINE"
|
subject_type = "LINE"
|
||||||
path = relevant_file.strip()
|
path = relevant_file.strip()
|
||||||
return dict(body=body, path=path, position=position, absolute_position=absolute_position) if subject_type == "LINE" else {}
|
return (
|
||||||
|
dict(
|
||||||
|
body=body,
|
||||||
|
path=path,
|
||||||
|
position=position,
|
||||||
|
absolute_position=absolute_position,
|
||||||
|
)
|
||||||
|
if subject_type == "LINE"
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
def publish_inline_comments(self, comments: list[dict], disable_fallback: bool = False):
|
def publish_inline_comments(
|
||||||
overall_success = True
|
self, comments: list[dict], disable_fallback: bool = False
|
||||||
for comment in comments:
|
):
|
||||||
try:
|
overall_success = True
|
||||||
self.publish_comment(comment["body"],
|
for comment in comments:
|
||||||
thread_context={
|
try:
|
||||||
"filePath": comment["path"],
|
self.publish_comment(
|
||||||
"rightFileStart": {
|
comment["body"],
|
||||||
"line": comment["absolute_position"],
|
thread_context={
|
||||||
"offset": comment["position"],
|
"filePath": comment["path"],
|
||||||
},
|
"rightFileStart": {
|
||||||
"rightFileEnd": {
|
"line": comment["absolute_position"],
|
||||||
"line": comment["absolute_position"],
|
"offset": comment["position"],
|
||||||
"offset": comment["position"],
|
},
|
||||||
},
|
"rightFileEnd": {
|
||||||
})
|
"line": comment["absolute_position"],
|
||||||
if get_settings().config.verbosity_level >= 2:
|
"offset": comment["position"],
|
||||||
get_logger().info(
|
},
|
||||||
f"Published code suggestion on {self.pr_num} at {comment['path']}"
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
if get_settings().config.verbosity_level >= 2:
|
get_logger().info(
|
||||||
get_logger().error(f"Failed to publish code suggestion, error: {e}")
|
f"Published code suggestion on {self.pr_num} at {comment['path']}"
|
||||||
overall_success = False
|
)
|
||||||
return overall_success
|
except Exception as e:
|
||||||
|
if get_settings().config.verbosity_level >= 2:
|
||||||
|
get_logger().error(f"Failed to publish code suggestion, error: {e}")
|
||||||
|
overall_success = False
|
||||||
|
return overall_success
|
||||||
|
|
||||||
def get_title(self):
|
def get_title(self):
|
||||||
return self.pr.title
|
return self.pr.title
|
||||||
@ -521,7 +595,11 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
def get_issue_comments(self):
|
def get_issue_comments(self):
|
||||||
threads = self.azure_devops_client.get_threads(repository_id=self.repo_slug, pull_request_id=self.pr_num, project=self.workspace_slug)
|
threads = self.azure_devops_client.get_threads(
|
||||||
|
repository_id=self.repo_slug,
|
||||||
|
pull_request_id=self.pr_num,
|
||||||
|
project=self.workspace_slug,
|
||||||
|
)
|
||||||
threads.reverse()
|
threads.reverse()
|
||||||
comment_list = []
|
comment_list = []
|
||||||
for thread in threads:
|
for thread in threads:
|
||||||
@ -532,7 +610,9 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
comment_list.append(comment)
|
comment_list.append(comment)
|
||||||
return comment_list
|
return comment_list
|
||||||
|
|
||||||
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
|
def add_eyes_reaction(
|
||||||
|
self, issue_comment_id: int, disable_eyes: bool = False
|
||||||
|
) -> Optional[int]:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
|
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
|
||||||
@ -547,16 +627,22 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The provided URL does not appear to be a Azure DevOps PR URL"
|
"The provided URL does not appear to be a Azure DevOps PR URL"
|
||||||
)
|
)
|
||||||
if len(path_parts) == 6: # "https://dev.azure.com/organization/project/_git/repo/pullrequest/1"
|
if (
|
||||||
|
len(path_parts) == 6
|
||||||
|
): # "https://dev.azure.com/organization/project/_git/repo/pullrequest/1"
|
||||||
workspace_slug = path_parts[1]
|
workspace_slug = path_parts[1]
|
||||||
repo_slug = path_parts[3]
|
repo_slug = path_parts[3]
|
||||||
pr_number = int(path_parts[5])
|
pr_number = int(path_parts[5])
|
||||||
elif len(path_parts) == 5: # 'https://organization.visualstudio.com/project/_git/repo/pullrequest/1'
|
elif (
|
||||||
|
len(path_parts) == 5
|
||||||
|
): # 'https://organization.visualstudio.com/project/_git/repo/pullrequest/1'
|
||||||
workspace_slug = path_parts[0]
|
workspace_slug = path_parts[0]
|
||||||
repo_slug = path_parts[2]
|
repo_slug = path_parts[2]
|
||||||
pr_number = int(path_parts[4])
|
pr_number = int(path_parts[4])
|
||||||
else:
|
else:
|
||||||
raise ValueError("The provided URL does not appear to be a Azure DevOps PR URL")
|
raise ValueError(
|
||||||
|
"The provided URL does not appear to be a Azure DevOps PR URL"
|
||||||
|
)
|
||||||
|
|
||||||
return workspace_slug, repo_slug, pr_number
|
return workspace_slug, repo_slug, pr_number
|
||||||
|
|
||||||
@ -575,12 +661,16 @@ class AzureDevopsProvider(GitProvider):
|
|||||||
# try to use azure default credentials
|
# try to use azure default credentials
|
||||||
# see https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python
|
# see https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python
|
||||||
# for usage and env var configuration of user-assigned managed identity, local machine auth etc.
|
# for usage and env var configuration of user-assigned managed identity, local machine auth etc.
|
||||||
get_logger().info("No PAT found in settings, trying to use Azure Default Credentials.")
|
get_logger().info(
|
||||||
|
"No PAT found in settings, trying to use Azure Default Credentials."
|
||||||
|
)
|
||||||
credentials = DefaultAzureCredential()
|
credentials = DefaultAzureCredential()
|
||||||
accessToken = credentials.get_token(ADO_APP_CLIENT_DEFAULT_ID)
|
accessToken = credentials.get_token(ADO_APP_CLIENT_DEFAULT_ID)
|
||||||
auth_token = accessToken.token
|
auth_token = accessToken.token
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"No PAT found in settings, and Azure Default Authentication failed, error: {e}")
|
get_logger().error(
|
||||||
|
f"No PAT found in settings, and Azure Default Authentication failed, error: {e}"
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
credentials = BasicAuthentication("", auth_token)
|
credentials = BasicAuthentication("", auth_token)
|
||||||
|
|||||||
@ -52,13 +52,19 @@ class BitbucketProvider(GitProvider):
|
|||||||
self.git_files = None
|
self.git_files = None
|
||||||
if pr_url:
|
if pr_url:
|
||||||
self.set_pr(pr_url)
|
self.set_pr(pr_url)
|
||||||
self.bitbucket_comment_api_url = self.pr._BitbucketBase__data["links"]["comments"]["href"]
|
self.bitbucket_comment_api_url = self.pr._BitbucketBase__data["links"][
|
||||||
self.bitbucket_pull_request_api_url = self.pr._BitbucketBase__data["links"]['self']['href']
|
"comments"
|
||||||
|
]["href"]
|
||||||
|
self.bitbucket_pull_request_api_url = self.pr._BitbucketBase__data["links"][
|
||||||
|
'self'
|
||||||
|
]['href']
|
||||||
|
|
||||||
def get_repo_settings(self):
|
def get_repo_settings(self):
|
||||||
try:
|
try:
|
||||||
url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
|
url = (
|
||||||
f"{self.pr.destination_branch}/.pr_agent.toml")
|
f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
|
||||||
|
f"{self.pr.destination_branch}/.pr_agent.toml"
|
||||||
|
)
|
||||||
response = requests.request("GET", url, headers=self.headers)
|
response = requests.request("GET", url, headers=self.headers)
|
||||||
if response.status_code == 404: # not found
|
if response.status_code == 404: # not found
|
||||||
return ""
|
return ""
|
||||||
@ -74,20 +80,27 @@ class BitbucketProvider(GitProvider):
|
|||||||
post_parameters_list = []
|
post_parameters_list = []
|
||||||
for suggestion in code_suggestions:
|
for suggestion in code_suggestions:
|
||||||
body = suggestion["body"]
|
body = suggestion["body"]
|
||||||
original_suggestion = suggestion.get('original_suggestion', None) # needed for diff code
|
original_suggestion = suggestion.get(
|
||||||
|
'original_suggestion', None
|
||||||
|
) # needed for diff code
|
||||||
if original_suggestion:
|
if original_suggestion:
|
||||||
try:
|
try:
|
||||||
existing_code = original_suggestion['existing_code'].rstrip() + "\n"
|
existing_code = original_suggestion['existing_code'].rstrip() + "\n"
|
||||||
improved_code = original_suggestion['improved_code'].rstrip() + "\n"
|
improved_code = original_suggestion['improved_code'].rstrip() + "\n"
|
||||||
diff = difflib.unified_diff(existing_code.split('\n'),
|
diff = difflib.unified_diff(
|
||||||
improved_code.split('\n'), n=999)
|
existing_code.split('\n'), improved_code.split('\n'), n=999
|
||||||
|
)
|
||||||
patch_orig = "\n".join(diff)
|
patch_orig = "\n".join(diff)
|
||||||
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
|
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
|
||||||
diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
|
diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
|
||||||
# replace ```suggestion ... ``` with diff_code, using regex:
|
# replace ```suggestion ... ``` with diff_code, using regex:
|
||||||
body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL)
|
body = re.sub(
|
||||||
|
r'```suggestion.*?```', diff_code, body, flags=re.DOTALL
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().exception(f"Bitbucket failed to get diff code for publishing, error: {e}")
|
get_logger().exception(
|
||||||
|
f"Bitbucket failed to get diff code for publishing, error: {e}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
relevant_file = suggestion["relevant_file"]
|
relevant_file = suggestion["relevant_file"]
|
||||||
@ -129,15 +142,22 @@ class BitbucketProvider(GitProvider):
|
|||||||
self.publish_inline_comments(post_parameters_list)
|
self.publish_inline_comments(post_parameters_list)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Bitbucket failed to publish code suggestion, error: {e}")
|
get_logger().error(
|
||||||
|
f"Bitbucket failed to publish code suggestion, error: {e}"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def publish_file_comments(self, file_comments: list) -> bool:
|
def publish_file_comments(self, file_comments: list) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def is_supported(self, capability: str) -> bool:
|
def is_supported(self, capability: str) -> bool:
|
||||||
if capability in ['get_issue_comments', 'publish_inline_comments', 'get_labels', 'gfm_markdown',
|
if capability in [
|
||||||
'publish_file_comments']:
|
'get_issue_comments',
|
||||||
|
'publish_inline_comments',
|
||||||
|
'get_labels',
|
||||||
|
'gfm_markdown',
|
||||||
|
'publish_file_comments',
|
||||||
|
]:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -169,12 +189,14 @@ class BitbucketProvider(GitProvider):
|
|||||||
names_original = [d.new.path for d in diffs_original]
|
names_original = [d.new.path for d in diffs_original]
|
||||||
names_kept = [d.new.path for d in diffs]
|
names_kept = [d.new.path for d in diffs]
|
||||||
names_filtered = list(set(names_original) - set(names_kept))
|
names_filtered = list(set(names_original) - set(names_kept))
|
||||||
get_logger().info(f"Filtered out [ignore] files for PR", extra={
|
get_logger().info(
|
||||||
'original_files': names_original,
|
f"Filtered out [ignore] files for PR",
|
||||||
'names_kept': names_kept,
|
extra={
|
||||||
'names_filtered': names_filtered
|
'original_files': names_original,
|
||||||
|
'names_kept': names_kept,
|
||||||
})
|
'names_filtered': names_filtered,
|
||||||
|
},
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -189,20 +211,32 @@ class BitbucketProvider(GitProvider):
|
|||||||
for encoding in encodings_to_try:
|
for encoding in encodings_to_try:
|
||||||
try:
|
try:
|
||||||
pr_patches = self.pr.diff(encoding=encoding)
|
pr_patches = self.pr.diff(encoding=encoding)
|
||||||
get_logger().info(f"Successfully decoded PR patch with encoding {encoding}")
|
get_logger().info(
|
||||||
|
f"Successfully decoded PR patch with encoding {encoding}"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if pr_patches is None:
|
if pr_patches is None:
|
||||||
raise ValueError(f"Failed to decode PR patch with encodings {encodings_to_try}")
|
raise ValueError(
|
||||||
|
f"Failed to decode PR patch with encodings {encodings_to_try}"
|
||||||
|
)
|
||||||
|
|
||||||
diff_split = ["diff --git" + x for x in pr_patches.split("diff --git") if x.strip()]
|
diff_split = [
|
||||||
|
"diff --git" + x for x in pr_patches.split("diff --git") if x.strip()
|
||||||
|
]
|
||||||
# filter all elements of 'diff_split' that are of indices in 'diffs_original' that are not in 'diffs'
|
# filter all elements of 'diff_split' that are of indices in 'diffs_original' that are not in 'diffs'
|
||||||
if len(diff_split) > len(diffs) and len(diffs_original) == len(diff_split):
|
if len(diff_split) > len(diffs) and len(diffs_original) == len(diff_split):
|
||||||
diff_split = [diff_split[i] for i in range(len(diff_split)) if diffs_original[i] in diffs]
|
diff_split = [
|
||||||
|
diff_split[i]
|
||||||
|
for i in range(len(diff_split))
|
||||||
|
if diffs_original[i] in diffs
|
||||||
|
]
|
||||||
if len(diff_split) != len(diffs):
|
if len(diff_split) != len(diffs):
|
||||||
get_logger().error(f"Error - failed to split the diff into {len(diffs)} parts")
|
get_logger().error(
|
||||||
|
f"Error - failed to split the diff into {len(diffs)} parts"
|
||||||
|
)
|
||||||
return []
|
return []
|
||||||
# bitbucket diff has a header for each file, we need to remove it:
|
# bitbucket diff has a header for each file, we need to remove it:
|
||||||
# "diff --git filename
|
# "diff --git filename
|
||||||
@ -213,22 +247,34 @@ class BitbucketProvider(GitProvider):
|
|||||||
# @@ -... @@"
|
# @@ -... @@"
|
||||||
for i, _ in enumerate(diff_split):
|
for i, _ in enumerate(diff_split):
|
||||||
diff_split_lines = diff_split[i].splitlines()
|
diff_split_lines = diff_split[i].splitlines()
|
||||||
if (len(diff_split_lines) >= 6) and \
|
if (len(diff_split_lines) >= 6) and (
|
||||||
((diff_split_lines[2].startswith("---") and
|
(
|
||||||
diff_split_lines[3].startswith("+++") and
|
diff_split_lines[2].startswith("---")
|
||||||
diff_split_lines[4].startswith("@@")) or
|
and diff_split_lines[3].startswith("+++")
|
||||||
(diff_split_lines[3].startswith("---") and # new or deleted file
|
and diff_split_lines[4].startswith("@@")
|
||||||
diff_split_lines[4].startswith("+++") and
|
)
|
||||||
diff_split_lines[5].startswith("@@"))):
|
or (
|
||||||
|
diff_split_lines[3].startswith("---")
|
||||||
|
and diff_split_lines[4].startswith("+++") # new or deleted file
|
||||||
|
and diff_split_lines[5].startswith("@@")
|
||||||
|
)
|
||||||
|
):
|
||||||
diff_split[i] = "\n".join(diff_split_lines[4:])
|
diff_split[i] = "\n".join(diff_split_lines[4:])
|
||||||
else:
|
else:
|
||||||
if diffs[i].data.get('lines_added', 0) == 0 and diffs[i].data.get('lines_removed', 0) == 0:
|
if (
|
||||||
|
diffs[i].data.get('lines_added', 0) == 0
|
||||||
|
and diffs[i].data.get('lines_removed', 0) == 0
|
||||||
|
):
|
||||||
diff_split[i] = ""
|
diff_split[i] = ""
|
||||||
elif len(diff_split_lines) <= 3:
|
elif len(diff_split_lines) <= 3:
|
||||||
diff_split[i] = ""
|
diff_split[i] = ""
|
||||||
get_logger().info(f"Disregarding empty diff for file {_gef_filename(diffs[i])}")
|
get_logger().info(
|
||||||
|
f"Disregarding empty diff for file {_gef_filename(diffs[i])}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
get_logger().warning(f"Bitbucket failed to get diff for file {_gef_filename(diffs[i])}")
|
get_logger().warning(
|
||||||
|
f"Bitbucket failed to get diff for file {_gef_filename(diffs[i])}"
|
||||||
|
)
|
||||||
diff_split[i] = ""
|
diff_split[i] = ""
|
||||||
|
|
||||||
invalid_files_names = []
|
invalid_files_names = []
|
||||||
@ -246,24 +292,32 @@ class BitbucketProvider(GitProvider):
|
|||||||
if get_settings().get("bitbucket_app.avoid_full_files", False):
|
if get_settings().get("bitbucket_app.avoid_full_files", False):
|
||||||
original_file_content_str = ""
|
original_file_content_str = ""
|
||||||
new_file_content_str = ""
|
new_file_content_str = ""
|
||||||
elif counter_valid < MAX_FILES_ALLOWED_FULL // 2: # factor 2 because bitbucket has limited API calls
|
elif (
|
||||||
|
counter_valid < MAX_FILES_ALLOWED_FULL // 2
|
||||||
|
): # factor 2 because bitbucket has limited API calls
|
||||||
if diff.old.get_data("links"):
|
if diff.old.get_data("links"):
|
||||||
original_file_content_str = self._get_pr_file_content(
|
original_file_content_str = self._get_pr_file_content(
|
||||||
diff.old.get_data("links")['self']['href'])
|
diff.old.get_data("links")['self']['href']
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
original_file_content_str = ""
|
original_file_content_str = ""
|
||||||
if diff.new.get_data("links"):
|
if diff.new.get_data("links"):
|
||||||
new_file_content_str = self._get_pr_file_content(diff.new.get_data("links")['self']['href'])
|
new_file_content_str = self._get_pr_file_content(
|
||||||
|
diff.new.get_data("links")['self']['href']
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
new_file_content_str = ""
|
new_file_content_str = ""
|
||||||
else:
|
else:
|
||||||
if counter_valid == MAX_FILES_ALLOWED_FULL // 2:
|
if counter_valid == MAX_FILES_ALLOWED_FULL // 2:
|
||||||
get_logger().info(
|
get_logger().info(
|
||||||
f"Bitbucket too many files in PR, will avoid loading full content for rest of files")
|
f"Bitbucket too many files in PR, will avoid loading full content for rest of files"
|
||||||
|
)
|
||||||
original_file_content_str = ""
|
original_file_content_str = ""
|
||||||
new_file_content_str = ""
|
new_file_content_str = ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().exception(f"Error - bitbucket failed to get file content, error: {e}")
|
get_logger().exception(
|
||||||
|
f"Error - bitbucket failed to get file content, error: {e}"
|
||||||
|
)
|
||||||
original_file_content_str = ""
|
original_file_content_str = ""
|
||||||
new_file_content_str = ""
|
new_file_content_str = ""
|
||||||
|
|
||||||
@ -285,7 +339,9 @@ class BitbucketProvider(GitProvider):
|
|||||||
diff_files.append(file_patch_canonic_structure)
|
diff_files.append(file_patch_canonic_structure)
|
||||||
|
|
||||||
if invalid_files_names:
|
if invalid_files_names:
|
||||||
get_logger().info(f"Disregarding files with invalid extensions:\n{invalid_files_names}")
|
get_logger().info(
|
||||||
|
f"Disregarding files with invalid extensions:\n{invalid_files_names}"
|
||||||
|
)
|
||||||
|
|
||||||
self.diff_files = diff_files
|
self.diff_files = diff_files
|
||||||
return diff_files
|
return diff_files
|
||||||
@ -296,11 +352,14 @@ class BitbucketProvider(GitProvider):
|
|||||||
def get_comment_url(self, comment):
|
def get_comment_url(self, comment):
|
||||||
return comment.data['links']['html']['href']
|
return comment.data['links']['html']['href']
|
||||||
|
|
||||||
def publish_persistent_comment(self, pr_comment: str,
|
def publish_persistent_comment(
|
||||||
initial_header: str,
|
self,
|
||||||
update_header: bool = True,
|
pr_comment: str,
|
||||||
name='review',
|
initial_header: str,
|
||||||
final_update_message=True):
|
update_header: bool = True,
|
||||||
|
name='review',
|
||||||
|
final_update_message=True,
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
for comment in self.pr.comments():
|
for comment in self.pr.comments():
|
||||||
body = comment.raw
|
body = comment.raw
|
||||||
@ -309,15 +368,20 @@ class BitbucketProvider(GitProvider):
|
|||||||
comment_url = self.get_comment_url(comment)
|
comment_url = self.get_comment_url(comment)
|
||||||
if update_header:
|
if update_header:
|
||||||
updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n"
|
updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n"
|
||||||
pr_comment_updated = pr_comment.replace(initial_header, updated_header)
|
pr_comment_updated = pr_comment.replace(
|
||||||
|
initial_header, updated_header
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
pr_comment_updated = pr_comment
|
pr_comment_updated = pr_comment
|
||||||
get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message")
|
get_logger().info(
|
||||||
|
f"Persistent mode - updating comment {comment_url} to latest {name} message"
|
||||||
|
)
|
||||||
d = {"content": {"raw": pr_comment_updated}}
|
d = {"content": {"raw": pr_comment_updated}}
|
||||||
response = comment._update_data(comment.put(None, data=d))
|
response = comment._update_data(comment.put(None, data=d))
|
||||||
if final_update_message:
|
if final_update_message:
|
||||||
self.publish_comment(
|
self.publish_comment(
|
||||||
f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}")
|
f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().exception(f"Failed to update persistent review, error: {e}")
|
get_logger().exception(f"Failed to update persistent review, error: {e}")
|
||||||
@ -326,7 +390,9 @@ class BitbucketProvider(GitProvider):
|
|||||||
|
|
||||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||||
if is_temporary and not get_settings().config.publish_output_progress:
|
if is_temporary and not get_settings().config.publish_output_progress:
|
||||||
get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
|
get_logger().debug(
|
||||||
|
f"Skipping publish_comment for temporary comment: {pr_comment}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
pr_comment = self.limit_output_characters(pr_comment, self.max_comment_length)
|
pr_comment = self.limit_output_characters(pr_comment, self.max_comment_length)
|
||||||
comment = self.pr.comment(pr_comment)
|
comment = self.pr.comment(pr_comment)
|
||||||
@ -355,39 +421,58 @@ class BitbucketProvider(GitProvider):
|
|||||||
get_logger().exception(f"Failed to remove comment, error: {e}")
|
get_logger().exception(f"Failed to remove comment, error: {e}")
|
||||||
|
|
||||||
# function to create_inline_comment
|
# function to create_inline_comment
|
||||||
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
|
def create_inline_comment(
|
||||||
absolute_position: int = None):
|
self,
|
||||||
|
body: str,
|
||||||
|
relevant_file: str,
|
||||||
|
relevant_line_in_file: str,
|
||||||
|
absolute_position: int = None,
|
||||||
|
):
|
||||||
body = self.limit_output_characters(body, self.max_comment_length)
|
body = self.limit_output_characters(body, self.max_comment_length)
|
||||||
position, absolute_position = find_line_number_of_relevant_line_in_file(self.get_diff_files(),
|
position, absolute_position = find_line_number_of_relevant_line_in_file(
|
||||||
relevant_file.strip('`'),
|
self.get_diff_files(),
|
||||||
relevant_line_in_file,
|
relevant_file.strip('`'),
|
||||||
absolute_position)
|
relevant_line_in_file,
|
||||||
|
absolute_position,
|
||||||
|
)
|
||||||
if position == -1:
|
if position == -1:
|
||||||
if get_settings().config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
|
get_logger().info(
|
||||||
|
f"Could not find position for {relevant_file} {relevant_line_in_file}"
|
||||||
|
)
|
||||||
subject_type = "FILE"
|
subject_type = "FILE"
|
||||||
else:
|
else:
|
||||||
subject_type = "LINE"
|
subject_type = "LINE"
|
||||||
path = relevant_file.strip()
|
path = relevant_file.strip()
|
||||||
return dict(body=body, path=path, position=absolute_position) if subject_type == "LINE" else {}
|
return (
|
||||||
|
dict(body=body, path=path, position=absolute_position)
|
||||||
|
if subject_type == "LINE"
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
def publish_inline_comment(self, comment: str, from_line: int, file: str, original_suggestion=None):
|
def publish_inline_comment(
|
||||||
|
self, comment: str, from_line: int, file: str, original_suggestion=None
|
||||||
|
):
|
||||||
comment = self.limit_output_characters(comment, self.max_comment_length)
|
comment = self.limit_output_characters(comment, self.max_comment_length)
|
||||||
payload = json.dumps({
|
payload = json.dumps(
|
||||||
"content": {
|
{
|
||||||
"raw": comment,
|
"content": {
|
||||||
},
|
"raw": comment,
|
||||||
"inline": {
|
},
|
||||||
"to": from_line,
|
"inline": {"to": from_line, "path": file},
|
||||||
"path": file
|
}
|
||||||
},
|
)
|
||||||
})
|
|
||||||
response = requests.request(
|
response = requests.request(
|
||||||
"POST", self.bitbucket_comment_api_url, data=payload, headers=self.headers
|
"POST", self.bitbucket_comment_api_url, data=payload, headers=self.headers
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
|
def get_line_link(
|
||||||
|
self,
|
||||||
|
relevant_file: str,
|
||||||
|
relevant_line_start: int,
|
||||||
|
relevant_line_end: int = None,
|
||||||
|
) -> str:
|
||||||
if relevant_line_start == -1:
|
if relevant_line_start == -1:
|
||||||
link = f"{self.pr_url}/#L{relevant_file}"
|
link = f"{self.pr_url}/#L{relevant_file}"
|
||||||
else:
|
else:
|
||||||
@ -402,8 +487,9 @@ class BitbucketProvider(GitProvider):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
diff_files = self.get_diff_files()
|
diff_files = self.get_diff_files()
|
||||||
position, absolute_position = find_line_number_of_relevant_line_in_file \
|
position, absolute_position = find_line_number_of_relevant_line_in_file(
|
||||||
(diff_files, relevant_file, relevant_line_str)
|
diff_files, relevant_file, relevant_line_str
|
||||||
|
)
|
||||||
|
|
||||||
if absolute_position != -1 and self.pr_url:
|
if absolute_position != -1 and self.pr_url:
|
||||||
link = f"{self.pr_url}/#L{relevant_file}T{absolute_position}"
|
link = f"{self.pr_url}/#L{relevant_file}T{absolute_position}"
|
||||||
@ -417,12 +503,18 @@ class BitbucketProvider(GitProvider):
|
|||||||
def publish_inline_comments(self, comments: list[dict]):
|
def publish_inline_comments(self, comments: list[dict]):
|
||||||
for comment in comments:
|
for comment in comments:
|
||||||
if 'position' in comment:
|
if 'position' in comment:
|
||||||
self.publish_inline_comment(comment['body'], comment['position'], comment['path'])
|
self.publish_inline_comment(
|
||||||
|
comment['body'], comment['position'], comment['path']
|
||||||
|
)
|
||||||
elif 'start_line' in comment: # multi-line comment
|
elif 'start_line' in comment: # multi-line comment
|
||||||
# note that bitbucket does not seem to support range - only a comment on a single line - https://community.developer.atlassian.com/t/api-post-endpoint-for-inline-pull-request-comments/60452
|
# note that bitbucket does not seem to support range - only a comment on a single line - https://community.developer.atlassian.com/t/api-post-endpoint-for-inline-pull-request-comments/60452
|
||||||
self.publish_inline_comment(comment['body'], comment['start_line'], comment['path'])
|
self.publish_inline_comment(
|
||||||
|
comment['body'], comment['start_line'], comment['path']
|
||||||
|
)
|
||||||
elif 'line' in comment: # single-line comment
|
elif 'line' in comment: # single-line comment
|
||||||
self.publish_inline_comment(comment['body'], comment['line'], comment['path'])
|
self.publish_inline_comment(
|
||||||
|
comment['body'], comment['line'], comment['path']
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
get_logger().error(f"Could not publish inline comment {comment}")
|
get_logger().error(f"Could not publish inline comment {comment}")
|
||||||
|
|
||||||
@ -450,7 +542,9 @@ class BitbucketProvider(GitProvider):
|
|||||||
"Bitbucket provider does not support issue comments yet"
|
"Bitbucket provider does not support issue comments yet"
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
|
def add_eyes_reaction(
|
||||||
|
self, issue_comment_id: int, disable_eyes: bool = False
|
||||||
|
) -> Optional[int]:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
|
     def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
@@ -495,8 +589,10 @@ class BitbucketProvider(GitProvider):
                 branch = self.pr.data["source"]["commit"]["hash"]
             elif branch == self.pr.destination_branch:
                 branch = self.pr.data["destination"]["commit"]["hash"]
-            url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
-                   f"{branch}/{file_path}")
+            url = (
+                f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
+                f"{branch}/{file_path}"
+            )
             response = requests.request("GET", url, headers=self.headers)
             if response.status_code == 404:  # not found
                 return ""
@@ -505,23 +601,28 @@ class BitbucketProvider(GitProvider):
         except Exception:
             return ""

-    def create_or_update_pr_file(self, file_path: str, branch: str, contents="", message="") -> None:
-        url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/")
+    def create_or_update_pr_file(
+        self, file_path: str, branch: str, contents="", message=""
+    ) -> None:
+        url = f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/"
         if not message:
             if contents:
                 message = f"Update {file_path}"
             else:
                 message = f"Create {file_path}"
         files = {file_path: contents}
-        data = {
-            "message": message,
-            "branch": branch
-        }
-        headers = {'Authorization': self.headers['Authorization']} if 'Authorization' in self.headers else {}
+        data = {"message": message, "branch": branch}
+        headers = (
+            {'Authorization': self.headers['Authorization']}
+            if 'Authorization' in self.headers
+            else {}
+        )
         try:
             requests.request("POST", url, headers=headers, data=data, files=files)
         except Exception:
-            get_logger().exception(f"Failed to create empty file {file_path} in branch {branch}")
+            get_logger().exception(
+                f"Failed to create empty file {file_path} in branch {branch}"
+            )

     def _get_pr_file_content(self, remote_link: str):
         try:
@@ -538,16 +639,19 @@ class BitbucketProvider(GitProvider):

     # bitbucket does not support labels
     def publish_description(self, pr_title: str, description: str):
-        payload = json.dumps({
-            "description": description,
-            "title": pr_title
-        })
-        response = requests.request("PUT", self.bitbucket_pull_request_api_url, headers=self.headers, data=payload)
+        payload = json.dumps({"description": description, "title": pr_title})
+
+        response = requests.request(
+            "PUT",
+            self.bitbucket_pull_request_api_url,
+            headers=self.headers,
+            data=payload,
+        )
         try:
             if response.status_code != 200:
-                get_logger().info(f"Failed to update description, error code: {response.status_code}")
+                get_logger().info(
+                    f"Failed to update description, error code: {response.status_code}"
+                )
         except:
             pass
         return response
@@ -11,8 +11,7 @@ from requests.exceptions import HTTPError
 from ..algo.git_patch_processing import decode_if_bytes
 from ..algo.language_handler import is_valid_file
 from ..algo.types import EDIT_TYPE, FilePatchInfo
-from ..algo.utils import (find_line_number_of_relevant_line_in_file,
-                          load_large_diff)
+from ..algo.utils import find_line_number_of_relevant_line_in_file, load_large_diff
 from ..config_loader import get_settings
 from ..log import get_logger
 from .git_provider import GitProvider
@@ -20,8 +19,10 @@ from .git_provider import GitProvider

 class BitbucketServerProvider(GitProvider):
     def __init__(
-            self, pr_url: Optional[str] = None, incremental: Optional[bool] = False,
-            bitbucket_client: Optional[Bitbucket] = None,
+        self,
+        pr_url: Optional[str] = None,
+        incremental: Optional[bool] = False,
+        bitbucket_client: Optional[Bitbucket] = None,
     ):
         self.bitbucket_server_url = None
         self.workspace_slug = None
@@ -36,11 +37,16 @@ class BitbucketServerProvider(GitProvider):
         self.bitbucket_pull_request_api_url = pr_url

         self.bitbucket_server_url = self._parse_bitbucket_server(url=pr_url)
-        self.bitbucket_client = bitbucket_client or Bitbucket(url=self.bitbucket_server_url,
-                                                              token=get_settings().get("BITBUCKET_SERVER.BEARER_TOKEN",
-                                                                                        None))
+        self.bitbucket_client = bitbucket_client or Bitbucket(
+            url=self.bitbucket_server_url,
+            token=get_settings().get("BITBUCKET_SERVER.BEARER_TOKEN", None),
+        )
         try:
-            self.bitbucket_api_version = parse_version(self.bitbucket_client.get("rest/api/1.0/application-properties").get('version'))
+            self.bitbucket_api_version = parse_version(
+                self.bitbucket_client.get("rest/api/1.0/application-properties").get(
+                    'version'
+                )
+            )
         except Exception:
             self.bitbucket_api_version = None

@@ -49,7 +55,12 @@ class BitbucketServerProvider(GitProvider):

     def get_repo_settings(self):
         try:
-            content = self.bitbucket_client.get_content_of_file(self.workspace_slug, self.repo_slug, ".pr_agent.toml", self.get_pr_branch())
+            content = self.bitbucket_client.get_content_of_file(
+                self.workspace_slug,
+                self.repo_slug,
+                ".pr_agent.toml",
+                self.get_pr_branch(),
+            )

             return content
         except Exception as e:
@@ -70,20 +81,27 @@ class BitbucketServerProvider(GitProvider):
         post_parameters_list = []
         for suggestion in code_suggestions:
             body = suggestion["body"]
-            original_suggestion = suggestion.get('original_suggestion', None)  # needed for diff code
+            original_suggestion = suggestion.get(
+                'original_suggestion', None
+            )  # needed for diff code
             if original_suggestion:
                 try:
                     existing_code = original_suggestion['existing_code'].rstrip() + "\n"
                     improved_code = original_suggestion['improved_code'].rstrip() + "\n"
-                    diff = difflib.unified_diff(existing_code.split('\n'),
-                                                improved_code.split('\n'), n=999)
+                    diff = difflib.unified_diff(
+                        existing_code.split('\n'), improved_code.split('\n'), n=999
+                    )
                     patch_orig = "\n".join(diff)
                     patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
                     diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
                     # replace ```suggestion ... ``` with diff_code, using regex:
-                    body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL)
+                    body = re.sub(
+                        r'```suggestion.*?```', diff_code, body, flags=re.DOTALL
+                    )
                 except Exception as e:
-                    get_logger().exception(f"Bitbucket failed to get diff code for publishing, error: {e}")
+                    get_logger().exception(
+                        f"Bitbucket failed to get diff code for publishing, error: {e}"
+                    )
                     continue
             relevant_file = suggestion["relevant_file"]
             relevant_lines_start = suggestion["relevant_lines_start"]
@@ -134,7 +152,12 @@ class BitbucketServerProvider(GitProvider):
             pass

     def is_supported(self, capability: str) -> bool:
-        if capability in ['get_issue_comments', 'get_labels', 'gfm_markdown', 'publish_file_comments']:
+        if capability in [
+            'get_issue_comments',
+            'get_labels',
+            'gfm_markdown',
+            'publish_file_comments',
+        ]:
             return False
         return True

@@ -145,23 +168,28 @@ class BitbucketServerProvider(GitProvider):
     def get_file(self, path: str, commit_id: str):
         file_content = ""
         try:
-            file_content = self.bitbucket_client.get_content_of_file(self.workspace_slug,
-                                                                     self.repo_slug,
-                                                                     path,
-                                                                     commit_id)
+            file_content = self.bitbucket_client.get_content_of_file(
+                self.workspace_slug, self.repo_slug, path, commit_id
+            )
         except HTTPError as e:
             get_logger().debug(f"File {path} not found at commit id: {commit_id}")
         return file_content

     def get_files(self):
-        changes = self.bitbucket_client.get_pull_requests_changes(self.workspace_slug, self.repo_slug, self.pr_num)
+        changes = self.bitbucket_client.get_pull_requests_changes(
+            self.workspace_slug, self.repo_slug, self.pr_num
+        )
         diffstat = [change["path"]['toString'] for change in changes]
         return diffstat

-    #gets the best common ancestor: https://git-scm.com/docs/git-merge-base
+    # gets the best common ancestor: https://git-scm.com/docs/git-merge-base
     @staticmethod
-    def get_best_common_ancestor(source_commits_list, destination_commits_list, guaranteed_common_ancestor) -> str:
-        destination_commit_hashes = {commit['id'] for commit in destination_commits_list} | {guaranteed_common_ancestor}
+    def get_best_common_ancestor(
+        source_commits_list, destination_commits_list, guaranteed_common_ancestor
+    ) -> str:
+        destination_commit_hashes = {
+            commit['id'] for commit in destination_commits_list
+        } | {guaranteed_common_ancestor}

         for commit in source_commits_list:
             for parent_commit in commit['parents']:
@@ -177,37 +205,55 @@ class BitbucketServerProvider(GitProvider):
         head_sha = self.pr.fromRef['latestCommit']

         # if Bitbucket api version is >= 8.16 then use the merge-base api for 2-way diff calculation
-        if self.bitbucket_api_version is not None and self.bitbucket_api_version >= parse_version("8.16"):
+        if (
+            self.bitbucket_api_version is not None
+            and self.bitbucket_api_version >= parse_version("8.16")
+        ):
             try:
                 base_sha = self.bitbucket_client.get(self._get_merge_base())['id']
             except Exception as e:
-                get_logger().error(f"Failed to get the best common ancestor for PR: {self.pr_url}, \nerror: {e}")
+                get_logger().error(
+                    f"Failed to get the best common ancestor for PR: {self.pr_url}, \nerror: {e}"
+                )
                 raise e
         else:
-            source_commits_list = list(self.bitbucket_client.get_pull_requests_commits(
-                self.workspace_slug,
-                self.repo_slug,
-                self.pr_num
-            ))
+            source_commits_list = list(
+                self.bitbucket_client.get_pull_requests_commits(
+                    self.workspace_slug, self.repo_slug, self.pr_num
+                )
+            )
             # if Bitbucket api version is None or < 7.0 then do a simple diff with a guaranteed common ancestor
             base_sha = source_commits_list[-1]['parents'][0]['id']
             # if Bitbucket api version is 7.0-8.15 then use 2-way diff functionality for the base_sha
-            if self.bitbucket_api_version is not None and self.bitbucket_api_version >= parse_version("7.0"):
+            if (
+                self.bitbucket_api_version is not None
+                and self.bitbucket_api_version >= parse_version("7.0")
+            ):
                 try:
                     destination_commits = list(
-                        self.bitbucket_client.get_commits(self.workspace_slug, self.repo_slug, base_sha,
-                                                          self.pr.toRef['latestCommit']))
-                    base_sha = self.get_best_common_ancestor(source_commits_list, destination_commits, base_sha)
+                        self.bitbucket_client.get_commits(
+                            self.workspace_slug,
+                            self.repo_slug,
+                            base_sha,
+                            self.pr.toRef['latestCommit'],
+                        )
+                    )
+                    base_sha = self.get_best_common_ancestor(
+                        source_commits_list, destination_commits, base_sha
+                    )
                 except Exception as e:
                     get_logger().error(
-                        f"Failed to get the commit list for calculating best common ancestor for PR: {self.pr_url}, \nerror: {e}")
+                        f"Failed to get the commit list for calculating best common ancestor for PR: {self.pr_url}, \nerror: {e}"
+                    )
                     raise e

         diff_files = []
         original_file_content_str = ""
         new_file_content_str = ""

-        changes = self.bitbucket_client.get_pull_requests_changes(self.workspace_slug, self.repo_slug, self.pr_num)
+        changes = self.bitbucket_client.get_pull_requests_changes(
+            self.workspace_slug, self.repo_slug, self.pr_num
+        )
         for change in changes:
             file_path = change['path']['toString']
             if not is_valid_file(file_path.split("/")[-1]):
@@ -224,17 +270,26 @@ class BitbucketServerProvider(GitProvider):
                     edit_type = EDIT_TYPE.DELETED
                     new_file_content_str = ""
                     original_file_content_str = self.get_file(file_path, base_sha)
-                    original_file_content_str = decode_if_bytes(original_file_content_str)
+                    original_file_content_str = decode_if_bytes(
+                        original_file_content_str
+                    )
                 case 'RENAME':
                     edit_type = EDIT_TYPE.RENAMED
                 case _:
                     edit_type = EDIT_TYPE.MODIFIED
                     original_file_content_str = self.get_file(file_path, base_sha)
-                    original_file_content_str = decode_if_bytes(original_file_content_str)
+                    original_file_content_str = decode_if_bytes(
+                        original_file_content_str
+                    )
                     new_file_content_str = self.get_file(file_path, head_sha)
                     new_file_content_str = decode_if_bytes(new_file_content_str)

-            patch = load_large_diff(file_path, new_file_content_str, original_file_content_str, show_warning=False)
+            patch = load_large_diff(
+                file_path,
+                new_file_content_str,
+                original_file_content_str,
+                show_warning=False,
+            )

             diff_files.append(
                 FilePatchInfo(
@@ -251,7 +306,9 @@ class BitbucketServerProvider(GitProvider):

     def publish_comment(self, pr_comment: str, is_temporary: bool = False):
         if not is_temporary:
-            self.bitbucket_client.add_pull_request_comment(self.workspace_slug, self.repo_slug, self.pr_num, pr_comment)
+            self.bitbucket_client.add_pull_request_comment(
+                self.workspace_slug, self.repo_slug, self.pr_num, pr_comment
+            )

     def remove_initial_comment(self):
         try:
@@ -264,25 +321,37 @@ class BitbucketServerProvider(GitProvider):
             pass

     # function to create_inline_comment
-    def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
-                              absolute_position: int = None):
+    def create_inline_comment(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_line_in_file: str,
+        absolute_position: int = None,
+    ):
         position, absolute_position = find_line_number_of_relevant_line_in_file(
             self.get_diff_files(),
             relevant_file.strip('`'),
             relevant_line_in_file,
-            absolute_position
+            absolute_position,
         )
         if position == -1:
             if get_settings().config.verbosity_level >= 2:
-                get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
+                get_logger().info(
+                    f"Could not find position for {relevant_file} {relevant_line_in_file}"
+                )
             subject_type = "FILE"
         else:
             subject_type = "LINE"
         path = relevant_file.strip()
-        return dict(body=body, path=path, position=absolute_position) if subject_type == "LINE" else {}
+        return (
+            dict(body=body, path=path, position=absolute_position)
+            if subject_type == "LINE"
+            else {}
+        )

-    def publish_inline_comment(self, comment: str, from_line: int, file: str, original_suggestion=None):
+    def publish_inline_comment(
+        self, comment: str, from_line: int, file: str, original_suggestion=None
+    ):
         payload = {
             "text": comment,
             "severity": "NORMAL",
@@ -291,17 +360,24 @@ class BitbucketServerProvider(GitProvider):
                 "path": file,
                 "lineType": "ADDED",
                 "line": from_line,
-                "fileType": "TO"
-            }
+                "fileType": "TO",
+            },
         }

         try:
             self.bitbucket_client.post(self._get_pr_comments_path(), data=payload)
         except Exception as e:
-            get_logger().error(f"Failed to publish inline comment to '{file}' at line {from_line}, error: {e}")
+            get_logger().error(
+                f"Failed to publish inline comment to '{file}' at line {from_line}, error: {e}"
+            )
             raise e

-    def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
+    def get_line_link(
+        self,
+        relevant_file: str,
+        relevant_line_start: int,
+        relevant_line_end: int = None,
+    ) -> str:
         if relevant_line_start == -1:
             link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}"
         else:
@@ -316,8 +392,9 @@ class BitbucketServerProvider(GitProvider):
                 return ""

             diff_files = self.get_diff_files()
-            position, absolute_position = find_line_number_of_relevant_line_in_file \
-                (diff_files, relevant_file, relevant_line_str)
+            position, absolute_position = find_line_number_of_relevant_line_in_file(
+                diff_files, relevant_file, relevant_line_str
+            )

             if absolute_position != -1:
                 if self.pr:
@@ -325,29 +402,41 @@ class BitbucketServerProvider(GitProvider):
                     return link
                 else:
                     if get_settings().config.verbosity_level >= 2:
-                        get_logger().info(f"Failed adding line link to '{relevant_file}' since PR not set")
+                        get_logger().info(
+                            f"Failed adding line link to '{relevant_file}' since PR not set"
+                        )
             else:
                 if get_settings().config.verbosity_level >= 2:
-                    get_logger().info(f"Failed adding line link to '{relevant_file}' since position not found")
+                    get_logger().info(
+                        f"Failed adding line link to '{relevant_file}' since position not found"
+                    )

             if absolute_position != -1 and self.pr_url:
                 link = f"{self.pr_url}/diff#{quote_plus(relevant_file)}?t={absolute_position}"
                 return link
         except Exception as e:
             if get_settings().config.verbosity_level >= 2:
-                get_logger().info(f"Failed adding line link to '{relevant_file}', error: {e}")
+                get_logger().info(
+                    f"Failed adding line link to '{relevant_file}', error: {e}"
+                )

         return ""

     def publish_inline_comments(self, comments: list[dict]):
         for comment in comments:
             if 'position' in comment:
-                self.publish_inline_comment(comment['body'], comment['position'], comment['path'])
-            elif 'start_line' in comment:  # multi-line comment
+                self.publish_inline_comment(
+                    comment['body'], comment['position'], comment['path']
+                )
+            elif 'start_line' in comment:  # multi-line comment
                 # note that bitbucket does not seem to support range - only a comment on a single line - https://community.developer.atlassian.com/t/api-post-endpoint-for-inline-pull-request-comments/60452
-                self.publish_inline_comment(comment['body'], comment['start_line'], comment['path'])
-            elif 'line' in comment:  # single-line comment
-                self.publish_inline_comment(comment['body'], comment['line'], comment['path'])
+                self.publish_inline_comment(
+                    comment['body'], comment['start_line'], comment['path']
+                )
+            elif 'line' in comment:  # single-line comment
+                self.publish_inline_comment(
+                    comment['body'], comment['line'], comment['path']
+                )
             else:
                 get_logger().error(f"Could not publish inline comment: {comment}")

@@ -377,7 +466,9 @@ class BitbucketServerProvider(GitProvider):
                 "Bitbucket provider does not support issue comments yet"
             )

-    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
+    def add_eyes_reaction(
+        self, issue_comment_id: int, disable_eyes: bool = False
+    ) -> Optional[int]:
         return True

     def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
@@ -411,14 +502,20 @@ class BitbucketServerProvider(GitProvider):
             users_index = -1

         if projects_index == -1 and users_index == -1:
-            raise ValueError(f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL")
+            raise ValueError(
+                f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL"
+            )

         if projects_index != -1:
             path_parts = path_parts[projects_index:]
         else:
             path_parts = path_parts[users_index:]

-        if len(path_parts) < 6 or path_parts[2] != "repos" or path_parts[4] != "pull-requests":
+        if (
+            len(path_parts) < 6
+            or path_parts[2] != "repos"
+            or path_parts[4] != "pull-requests"
+        ):
             raise ValueError(
                 f"The provided URL '{pr_url}' does not appear to be a Bitbucket PR URL"
             )
@@ -430,19 +527,24 @@ class BitbucketServerProvider(GitProvider):
         try:
             pr_number = int(path_parts[5])
         except ValueError as e:
-            raise ValueError(f"Unable to convert PR number '{path_parts[5]}' to integer") from e
+            raise ValueError(
+                f"Unable to convert PR number '{path_parts[5]}' to integer"
+            ) from e

         return workspace_slug, repo_slug, pr_number

     def _get_repo(self):
         if self.repo is None:
-            self.repo = self.bitbucket_client.get_repo(self.workspace_slug, self.repo_slug)
+            self.repo = self.bitbucket_client.get_repo(
+                self.workspace_slug, self.repo_slug
+            )
         return self.repo

     def _get_pr(self):
         try:
-            pr = self.bitbucket_client.get_pull_request(self.workspace_slug, self.repo_slug,
-                                                        pull_request_id=self.pr_num)
+            pr = self.bitbucket_client.get_pull_request(
+                self.workspace_slug, self.repo_slug, pull_request_id=self.pr_num
+            )
             return type('new_dict', (object,), pr)
         except Exception as e:
             get_logger().error(f"Failed to get pull request, error: {e}")
@@ -460,10 +562,12 @@ class BitbucketServerProvider(GitProvider):
             "version": self.pr.version,
             "description": description,
             "title": pr_title,
-            "reviewers": self.pr.reviewers  # needs to be sent otherwise gets wiped
+            "reviewers": self.pr.reviewers,  # needs to be sent otherwise gets wiped
         }
         try:
-            self.bitbucket_client.update_pull_request(self.workspace_slug, self.repo_slug, str(self.pr_num), payload)
+            self.bitbucket_client.update_pull_request(
+                self.workspace_slug, self.repo_slug, str(self.pr_num), payload
+            )
         except Exception as e:
             get_logger().error(f"Failed to update pull request, error: {e}")
             raise e
@@ -31,7 +31,9 @@ class CodeCommitPullRequestResponse:

         self.targets = []
         for target in json.get("pullRequestTargets", []):
-            self.targets.append(CodeCommitPullRequestResponse.CodeCommitPullRequestTarget(target))
+            self.targets.append(
+                CodeCommitPullRequestResponse.CodeCommitPullRequestTarget(target)
+            )

     class CodeCommitPullRequestTarget:
         """
@@ -65,7 +67,9 @@ class CodeCommitClient:
         except Exception as e:
             raise ValueError(f"Failed to connect to AWS CodeCommit: {e}") from e

-    def get_differences(self, repo_name: int, destination_commit: str, source_commit: str):
+    def get_differences(
+        self, repo_name: int, destination_commit: str, source_commit: str
+    ):
         """
         Get the differences between two commits in CodeCommit.

@@ -96,17 +100,25 @@ class CodeCommitClient:
                 differences.extend(page.get("differences", []))
         except botocore.exceptions.ClientError as e:
             if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
-                raise ValueError(f"CodeCommit cannot retrieve differences: Repository does not exist: {repo_name}") from e
-            raise ValueError(f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}") from e
+                raise ValueError(
+                    f"CodeCommit cannot retrieve differences: Repository does not exist: {repo_name}"
+                ) from e
+            raise ValueError(
+                f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}"
+            ) from e
         except Exception as e:
-            raise ValueError(f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}") from e
+            raise ValueError(
+                f"CodeCommit cannot retrieve differences for {source_commit}..{destination_commit}"
+            ) from e

         output = []
         for json in differences:
             output.append(CodeCommitDifferencesResponse(json))
         return output

-    def get_file(self, repo_name: str, file_path: str, sha_hash: str, optional: bool = False):
+    def get_file(
+        self, repo_name: str, file_path: str, sha_hash: str, optional: bool = False
+    ):
         """
         Retrieve a file from CodeCommit.

@@ -129,16 +141,24 @@ class CodeCommitClient:
         self._connect_boto_client()

         try:
-            response = self.boto_client.get_file(repositoryName=repo_name, commitSpecifier=sha_hash, filePath=file_path)
+            response = self.boto_client.get_file(
+                repositoryName=repo_name, commitSpecifier=sha_hash, filePath=file_path
+            )
         except botocore.exceptions.ClientError as e:
             if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
-                raise ValueError(f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}") from e
+                raise ValueError(
+                    f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}"
+                ) from e
             # if the file does not exist, but is flagged as optional, then return an empty string
             if optional and e.response["Error"]["Code"] == 'FileDoesNotExistException':
                 return ""
-            raise ValueError(f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'") from e
+            raise ValueError(
+                f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'"
+            ) from e
         except Exception as e:
-            raise ValueError(f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'") from e
+            raise ValueError(
+                f"CodeCommit cannot retrieve file '{file_path}' from repository '{repo_name}'"
+            ) from e
         if "fileContent" not in response:
             raise ValueError(f"File content is empty for file: {file_path}")

@@ -166,10 +186,16 @@ class CodeCommitClient:
             response = self.boto_client.get_pull_request(pullRequestId=str(pr_number))
         except botocore.exceptions.ClientError as e:
             if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
-                raise ValueError(f"CodeCommit cannot retrieve PR: PR number does not exist: {pr_number}") from e
+                raise ValueError(
+                    f"CodeCommit cannot retrieve PR: PR number does not exist: {pr_number}"
+                ) from e
             if e.response["Error"]["Code"] == 'RepositoryDoesNotExistException':
-                raise ValueError(f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}") from e
-            raise ValueError(f"CodeCommit cannot retrieve PR: {pr_number}: boto client error") from e
+                raise ValueError(
+                    f"CodeCommit cannot retrieve PR: Repository does not exist: {repo_name}"
+                ) from e
+            raise ValueError(
+                f"CodeCommit cannot retrieve PR: {pr_number}: boto client error"
+            ) from e
         except Exception as e:
             raise ValueError(f"CodeCommit cannot retrieve PR: {pr_number}") from e

@@ -200,22 +226,37 @@ class CodeCommitClient:
         self._connect_boto_client()

         try:
-            self.boto_client.update_pull_request_title(pullRequestId=str(pr_number), title=pr_title)
-            self.boto_client.update_pull_request_description(pullRequestId=str(pr_number), description=pr_body)
+            self.boto_client.update_pull_request_title(
+                pullRequestId=str(pr_number), title=pr_title
+            )
+            self.boto_client.update_pull_request_description(
+                pullRequestId=str(pr_number), description=pr_body
+            )
         except botocore.exceptions.ClientError as e:
             if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
                 raise ValueError(f"PR number does not exist: {pr_number}") from e
             if e.response["Error"]["Code"] == 'InvalidTitleException':
                 raise ValueError(f"Invalid title for PR number: {pr_number}") from e
             if e.response["Error"]["Code"] == 'InvalidDescriptionException':
-                raise ValueError(f"Invalid description for PR number: {pr_number}") from e
+                raise ValueError(
+                    f"Invalid description for PR number: {pr_number}"
+                ) from e
             if e.response["Error"]["Code"] == 'PullRequestAlreadyClosedException':
                 raise ValueError(f"PR is already closed: PR number: {pr_number}") from e
             raise ValueError(f"Boto3 client error calling publish_description") from e
         except Exception as e:
             raise ValueError(f"Error calling publish_description") from e

-    def publish_comment(self, repo_name: str, pr_number: int, destination_commit: str, source_commit: str, comment: str, annotation_file: str = None, annotation_line: int = None):
+    def publish_comment(
+        self,
+        repo_name: str,
+        pr_number: int,
+        destination_commit: str,
+        source_commit: str,
+        comment: str,
+        annotation_file: str = None,
+        annotation_line: int = None,
+    ):
         """
         Publish a comment to a pull request

@@ -272,6 +313,8 @@ class CodeCommitClient:
                 raise ValueError(f"Repository does not exist: {repo_name}") from e
             if e.response["Error"]["Code"] == 'PullRequestDoesNotExistException':
                 raise ValueError(f"PR number does not exist: {pr_number}") from e
-            raise ValueError(f"Boto3 client error calling post_comment_for_pull_request") from e
+            raise ValueError(
+                f"Boto3 client error calling post_comment_for_pull_request"
+            ) from e
         except Exception as e:
             raise ValueError(f"Error calling post_comment_for_pull_request") from e
@@ -55,7 +55,9 @@ class CodeCommitProvider(GitProvider):
     This class implements the GitProvider interface for AWS CodeCommit repositories.
     """

-    def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
+    def __init__(
+        self, pr_url: Optional[str] = None, incremental: Optional[bool] = False
+    ):
         self.codecommit_client = CodeCommitClient()
         self.aws_client = None
         self.repo_name = None
@@ -76,7 +78,7 @@ class CodeCommitProvider(GitProvider):
             "create_inline_comment",
             "publish_inline_comments",
             "get_labels",
-            "gfm_markdown"
+            "gfm_markdown",
         ]:
             return False
         return True
@@ -91,13 +93,19 @@ class CodeCommitProvider(GitProvider):
             return self.git_files

         self.git_files = []
-        differences = self.codecommit_client.get_differences(self.repo_name, self.pr.destination_commit, self.pr.source_commit)
+        differences = self.codecommit_client.get_differences(
+            self.repo_name, self.pr.destination_commit, self.pr.source_commit
+        )
         for item in differences:
-            self.git_files.append(CodeCommitFile(item.before_blob_path,
-                                                 item.before_blob_id,
-                                                 item.after_blob_path,
-                                                 item.after_blob_id,
-                                                 CodeCommitProvider._get_edit_type(item.change_type)))
+            self.git_files.append(
+                CodeCommitFile(
+                    item.before_blob_path,
+                    item.before_blob_id,
+                    item.after_blob_path,
+                    item.after_blob_id,
+                    CodeCommitProvider._get_edit_type(item.change_type),
+                )
+            )
         return self.git_files

     def get_diff_files(self) -> list[FilePatchInfo]:
@@ -121,21 +129,28 @@ class CodeCommitProvider(GitProvider):
             if diff_item.a_blob_id is not None:
                 patch_filename = diff_item.a_path
                 original_file_content_str = self.codecommit_client.get_file(
-                    self.repo_name, diff_item.a_path, self.pr.destination_commit)
+                    self.repo_name, diff_item.a_path, self.pr.destination_commit
+                )
                 if isinstance(original_file_content_str, (bytes, bytearray)):
-                    original_file_content_str = original_file_content_str.decode("utf-8")
+                    original_file_content_str = original_file_content_str.decode(
+                        "utf-8"
+                    )
             else:
                 original_file_content_str = ""

             if diff_item.b_blob_id is not None:
                 patch_filename = diff_item.b_path
-                new_file_content_str = self.codecommit_client.get_file(self.repo_name, diff_item.b_path, self.pr.source_commit)
+                new_file_content_str = self.codecommit_client.get_file(
+                    self.repo_name, diff_item.b_path, self.pr.source_commit
+                )
                 if isinstance(new_file_content_str, (bytes, bytearray)):
                     new_file_content_str = new_file_content_str.decode("utf-8")
             else:
                 new_file_content_str = ""

-            patch = load_large_diff(patch_filename, new_file_content_str, original_file_content_str)
+            patch = load_large_diff(
+                patch_filename, new_file_content_str, original_file_content_str
+            )

             # Store the diffs as a list of FilePatchInfo objects
             info = FilePatchInfo(
@@ -164,7 +179,9 @@ class CodeCommitProvider(GitProvider):
                 pr_body=CodeCommitProvider._add_additional_newlines(pr_body),
             )
         except Exception as e:
-            raise ValueError(f"CodeCommit Cannot publish description for PR: {self.pr_num}") from e
+            raise ValueError(
+                f"CodeCommit Cannot publish description for PR: {self.pr_num}"
+            ) from e

     def publish_comment(self, pr_comment: str, is_temporary: bool = False):
         if is_temporary:
@@ -183,19 +200,28 @@ class CodeCommitProvider(GitProvider):
                 comment=pr_comment,
             )
         except Exception as e:
-            raise ValueError(f"CodeCommit Cannot publish comment for PR: {self.pr_num}") from e
+            raise ValueError(
+                f"CodeCommit Cannot publish comment for PR: {self.pr_num}"
+            ) from e

     def publish_code_suggestions(self, code_suggestions: list) -> bool:
         counter = 1
         for suggestion in code_suggestions:
             # Verify that each suggestion has the required keys
-            if not all(key in suggestion for key in ["body", "relevant_file", "relevant_lines_start"]):
-                get_logger().warning(f"Skipping code suggestion #{counter}: Each suggestion must have 'body', 'relevant_file', 'relevant_lines_start' keys")
+            if not all(
+                key in suggestion
+                for key in ["body", "relevant_file", "relevant_lines_start"]
+            ):
+                get_logger().warning(
+                    f"Skipping code suggestion #{counter}: Each suggestion must have 'body', 'relevant_file', 'relevant_lines_start' keys"
+                )
                 continue

             # Publish the code suggestion to CodeCommit
             try:
-                get_logger().debug(f"Code Suggestion #{counter} in file: {suggestion['relevant_file']}: {suggestion['relevant_lines_start']}")
+                get_logger().debug(
+                    f"Code Suggestion #{counter} in file: {suggestion['relevant_file']}: {suggestion['relevant_lines_start']}"
+                )
                 self.codecommit_client.publish_comment(
                     repo_name=self.repo_name,
                     pr_number=self.pr_num,
@@ -206,7 +232,9 @@ class CodeCommitProvider(GitProvider):
                     annotation_line=suggestion["relevant_lines_start"],
                 )
             except Exception as e:
-                raise ValueError(f"CodeCommit Cannot publish code suggestions for PR: {self.pr_num}") from e
+                raise ValueError(
+                    f"CodeCommit Cannot publish code suggestions for PR: {self.pr_num}"
+                ) from e

             counter += 1

@@ -227,12 +255,22 @@ class CodeCommitProvider(GitProvider):
     def remove_comment(self, comment):
         return ""  # not implemented yet

-    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
+    def publish_inline_comment(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_line_in_file: str,
+        original_suggestion=None,
+    ):
         # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/codecommit/client/post_comment_for_compared_commit.html
-        raise NotImplementedError("CodeCommit provider does not support publishing inline comments yet")
+        raise NotImplementedError(
+            "CodeCommit provider does not support publishing inline comments yet"
+        )

     def publish_inline_comments(self, comments: list[dict]):
-        raise NotImplementedError("CodeCommit provider does not support publishing inline comments yet")
+        raise NotImplementedError(
+            "CodeCommit provider does not support publishing inline comments yet"
+        )

     def get_title(self):
         return self.pr.title
@@ -257,7 +295,7 @@ class CodeCommitProvider(GitProvider):
         - dict: A dictionary where each key is a language name and the corresponding value is the percentage of that language in the PR.
         """
         commit_files = self.get_files()
-        filenames = [ item.filename for item in commit_files ]
+        filenames = [item.filename for item in commit_files]
         extensions = CodeCommitProvider._get_file_extensions(filenames)

         # Calculate the percentage of each file extension in the PR
@@ -270,7 +308,9 @@ class CodeCommitProvider(GitProvider):
         # We build that language->extension dictionary here in main_extensions_flat.
         main_extensions_flat = {}
         language_extension_map_org = get_settings().language_extension_map_org
-        language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
+        language_extension_map = {
+            k.lower(): v for k, v in language_extension_map_org.items()
+        }
         for language, extensions in language_extension_map.items():
             for ext in extensions:
                 main_extensions_flat[ext] = language
@@ -292,14 +332,20 @@ class CodeCommitProvider(GitProvider):
         return -1  # not implemented yet

     def get_issue_comments(self):
-        raise NotImplementedError("CodeCommit provider does not support issue comments yet")
+        raise NotImplementedError(
+            "CodeCommit provider does not support issue comments yet"
+        )

     def get_repo_settings(self):
         # a local ".pr_agent.toml" settings file is optional
         settings_filename = ".pr_agent.toml"
-        return self.codecommit_client.get_file(self.repo_name, settings_filename, self.pr.source_commit, optional=True)
+        return self.codecommit_client.get_file(
+            self.repo_name, settings_filename, self.pr.source_commit, optional=True
+        )

-    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
+    def add_eyes_reaction(
+        self, issue_comment_id: int, disable_eyes: bool = False
+    ) -> Optional[int]:
         get_logger().info("CodeCommit provider does not support eyes reaction yet")
         return True

@@ -323,7 +369,9 @@ class CodeCommitProvider(GitProvider):
         parsed_url = urlparse(pr_url)

         if not CodeCommitProvider._is_valid_codecommit_hostname(parsed_url.netloc):
-            raise ValueError(f"The provided URL is not a valid CodeCommit URL: {pr_url}")
+            raise ValueError(
+                f"The provided URL is not a valid CodeCommit URL: {pr_url}"
+            )

         path_parts = parsed_url.path.strip("/").split("/")

@@ -334,14 +382,18 @@ class CodeCommitProvider(GitProvider):
             or path_parts[2] != "repositories"
             or path_parts[4] != "pull-requests"
         ):
-            raise ValueError(f"The provided URL does not appear to be a CodeCommit PR URL: {pr_url}")
+            raise ValueError(
+                f"The provided URL does not appear to be a CodeCommit PR URL: {pr_url}"
+            )

         repo_name = path_parts[3]

         try:
             pr_number = int(path_parts[5])
         except ValueError as e:
-            raise ValueError(f"Unable to convert PR number to integer: '{path_parts[5]}'") from e
+            raise ValueError(
+                f"Unable to convert PR number to integer: '{path_parts[5]}'"
+            ) from e

         return repo_name, pr_number

@@ -359,7 +411,12 @@ class CodeCommitProvider(GitProvider):
         Returns:
         - bool: True if the hostname is valid, False otherwise.
         """
-        return re.match(r"^[a-z]{2}-(gov-)?[a-z]+-\d\.console\.aws\.amazon\.com$", hostname) is not None
+        return (
+            re.match(
+                r"^[a-z]{2}-(gov-)?[a-z]+-\d\.console\.aws\.amazon\.com$", hostname
+            )
+            is not None
+        )

     def _get_pr(self):
         response = self.codecommit_client.get_pr(self.repo_name, self.pr_num)
@@ -38,10 +38,7 @@ def clone(url, directory):

 def fetch(url, refspec, cwd):
     get_logger().info("Fetching %s %s", url, refspec)
-    stdout = _call(
-        'git', 'fetch', '--depth', '2', url, refspec,
-        cwd=cwd
-    )
+    stdout = _call('git', 'fetch', '--depth', '2', url, refspec, cwd=cwd)
     get_logger().info(stdout)


@@ -75,10 +72,13 @@ def add_comment(url: urllib3.util.Url, refspec, message):
     message = "'" + message.replace("'", "'\"'\"'") + "'"
     return _call(
         "ssh",
-        "-p", str(url.port),
+        "-p",
+        str(url.port),
         f"{url.auth}@{url.host}",
-        "gerrit", "review",
-        "--message", message,
+        "gerrit",
+        "review",
+        "--message",
+        message,
         # "--code-review", score,
         f"{patchset},{changenum}",
     )
@@ -88,19 +88,23 @@ def list_comments(url: urllib3.util.Url, refspec):
     *_, patchset, _ = refspec.rsplit("/")
     stdout = _call(
         "ssh",
-        "-p", str(url.port),
+        "-p",
+        str(url.port),
         f"{url.auth}@{url.host}",
-        "gerrit", "query",
+        "gerrit",
+        "query",
         "--comments",
-        "--current-patch-set", patchset,
-        "--format", "JSON",
+        "--current-patch-set",
+        patchset,
+        "--format",
+        "JSON",
     )
     change_set, *_ = stdout.splitlines()
     return json.loads(change_set)["currentPatchSet"]["comments"]


 def prepare_repo(url: urllib3.util.Url, project, refspec):
-    repo_url = (f"{url.scheme}://{url.auth}@{url.host}:{url.port}/{project}")
+    repo_url = f"{url.scheme}://{url.auth}@{url.host}:{url.port}/{project}"

     directory = pathlib.Path(mkdtemp())
     clone(repo_url, directory),
@@ -114,18 +118,18 @@ def adopt_to_gerrit_message(message):
     buf = []
     for line in lines:
         # remove markdown formatting
-        line = (line.replace("*", "")
-                .replace("``", "`")
-                .replace("<details>", "")
-                .replace("</details>", "")
-                .replace("<summary>", "")
-                .replace("</summary>", ""))
+        line = (
+            line.replace("*", "")
+            .replace("``", "`")
+            .replace("<details>", "")
+            .replace("</details>", "")
+            .replace("<summary>", "")
+            .replace("</summary>", "")
+        )

         line = line.strip()
         if line.startswith('#'):
-            buf.append("\n" +
-                       line.replace('#', '').removesuffix(":").strip() +
-                       ":")
+            buf.append("\n" + line.replace('#', '').removesuffix(":").strip() + ":")
             continue
         elif line.startswith('-'):
             buf.append(line.removeprefix('-').strip())
@@ -136,12 +140,9 @@ def adopt_to_gerrit_message(message):


 def add_suggestion(src_filename, context: str, start, end: int):
-    with (
-        NamedTemporaryFile("w", delete=False) as tmp,
-        open(src_filename, "r") as src
-    ):
+    with NamedTemporaryFile("w", delete=False) as tmp, open(src_filename, "r") as src:
         lines = src.readlines()
-        tmp.writelines(lines[:start - 1])
+        tmp.writelines(lines[: start - 1])
         if context:
             tmp.write(context)
         tmp.writelines(lines[end:])
@@ -151,10 +152,8 @@ def add_suggestion(src_filename, context: str, start, end: int):


 def upload_patch(patch, path):
-    patch_server_endpoint = get_settings().get(
-        'gerrit.patch_server_endpoint')
-    patch_server_token = get_settings().get(
-        'gerrit.patch_server_token')
+    patch_server_endpoint = get_settings().get('gerrit.patch_server_endpoint')
+    patch_server_token = get_settings().get('gerrit.patch_server_token')

     response = requests.post(
         patch_server_endpoint,
@@ -165,7 +164,7 @@ def upload_patch(patch, path):
         headers={
             "Content-Type": "application/json",
             "Authorization": f"Bearer {patch_server_token}",
-        }
+        },
     )
     response.raise_for_status()
     patch_server_endpoint = patch_server_endpoint.rstrip("/")
@@ -173,7 +172,6 @@ def upload_patch(patch, path):


 class GerritProvider(GitProvider):
-
     def __init__(self, key: str, incremental=False):
         self.project, self.refspec = key.split(':')
         assert self.project, "Project name is required"
@@ -188,9 +186,7 @@ class GerritProvider(GitProvider):
             f"{parsed.scheme}://{user}@{parsed.host}:{parsed.port}"
         )

-        self.repo_path = prepare_repo(
-            self.parsed_url, self.project, self.refspec
-        )
+        self.repo_path = prepare_repo(self.parsed_url, self.project, self.refspec)
         self.repo = Repo(self.repo_path)
         assert self.repo
         self.pr_url = base_url
@@ -210,15 +206,18 @@ class GerritProvider(GitProvider):

     def get_pr_labels(self, update=False):
         raise NotImplementedError(
-            'Getting labels is not implemented for the gerrit provider')
+            'Getting labels is not implemented for the gerrit provider'
+        )

     def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False):
         raise NotImplementedError(
-            'Adding reactions is not implemented for the gerrit provider')
+            'Adding reactions is not implemented for the gerrit provider'
+        )

     def remove_reaction(self, issue_comment_id: int, reaction_id: int):
         raise NotImplementedError(
-            'Removing reactions is not implemented for the gerrit provider')
+            'Removing reactions is not implemented for the gerrit provider'
+        )

     def get_commit_messages(self):
         return [self.repo.head.commit.message]
@@ -235,20 +234,21 @@ class GerritProvider(GitProvider):
         diffs = self.repo.head.commit.diff(
             self.repo.head.commit.parents[0],  # previous commit
             create_patch=True,
-            R=True
+            R=True,
         )

         diff_files = []
         for diff_item in diffs:
             if diff_item.a_blob is not None:
-                original_file_content_str = (
-                    diff_item.a_blob.data_stream.read().decode('utf-8')
+                original_file_content_str = diff_item.a_blob.data_stream.read().decode(
+                    'utf-8'
                 )
             else:
                 original_file_content_str = ""  # empty file
             if diff_item.b_blob is not None:
-                new_file_content_str = diff_item.b_blob.data_stream.read(). \
-                    decode('utf-8')
+                new_file_content_str = diff_item.b_blob.data_stream.read().decode(
+                    'utf-8'
+                )
             else:
                 new_file_content_str = ""  # empty file
|
new_file_content_str = "" # empty file
|
||||||
edit_type = EDIT_TYPE.MODIFIED
|
edit_type = EDIT_TYPE.MODIFIED
|
||||||
@ -267,7 +267,7 @@ class GerritProvider(GitProvider):
|
|||||||
edit_type=edit_type,
|
edit_type=edit_type,
|
||||||
old_filename=None
|
old_filename=None
|
||||||
if diff_item.a_path == diff_item.b_path
|
if diff_item.a_path == diff_item.b_path
|
||||||
else diff_item.a_path
|
else diff_item.a_path,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.diff_files = diff_files
|
self.diff_files = diff_files
|
||||||
@ -275,8 +275,7 @@ class GerritProvider(GitProvider):
|
|||||||
|
|
||||||
def get_files(self):
|
def get_files(self):
|
||||||
diff_index = self.repo.head.commit.diff(
|
diff_index = self.repo.head.commit.diff(
|
||||||
self.repo.head.commit.parents[0], # previous commit
|
self.repo.head.commit.parents[0], R=True # previous commit
|
||||||
R=True
|
|
||||||
)
|
)
|
||||||
# Get the list of changed files
|
# Get the list of changed files
|
||||||
diff_files = [item.a_path for item in diff_index]
|
diff_files = [item.a_path for item in diff_index]
|
||||||
@ -288,16 +287,22 @@ class GerritProvider(GitProvider):
|
|||||||
prioritisation.
|
prioritisation.
|
||||||
"""
|
"""
|
||||||
# Get all files in repository
|
# Get all files in repository
|
||||||
filepaths = [Path(item.path) for item in
|
filepaths = [
|
||||||
self.repo.tree().traverse() if item.type == 'blob']
|
Path(item.path)
|
||||||
|
for item in self.repo.tree().traverse()
|
||||||
|
if item.type == 'blob'
|
||||||
|
]
|
||||||
# Identify language by file extension and count
|
# Identify language by file extension and count
|
||||||
lang_count = Counter(
|
lang_count = Counter(
|
||||||
ext.lstrip('.') for filepath in filepaths for ext in
|
ext.lstrip('.')
|
||||||
[filepath.suffix.lower()])
|
for filepath in filepaths
|
||||||
|
for ext in [filepath.suffix.lower()]
|
||||||
|
)
|
||||||
# Convert counts to percentages
|
# Convert counts to percentages
|
||||||
total_files = len(filepaths)
|
total_files = len(filepaths)
|
||||||
lang_percentage = {lang: count / total_files * 100 for lang, count
|
lang_percentage = {
|
||||||
in lang_count.items()}
|
lang: count / total_files * 100 for lang, count in lang_count.items()
|
||||||
|
}
|
||||||
return lang_percentage
|
return lang_percentage
|
||||||
|
|
||||||
def get_pr_description_full(self):
|
def get_pr_description_full(self):
|
||||||
@ -312,7 +317,7 @@ class GerritProvider(GitProvider):
|
|||||||
'create_inline_comment',
|
'create_inline_comment',
|
||||||
'publish_inline_comments',
|
'publish_inline_comments',
|
||||||
'get_labels',
|
'get_labels',
|
||||||
'gfm_markdown'
|
'gfm_markdown',
|
||||||
]:
|
]:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
@ -331,14 +336,9 @@ class GerritProvider(GitProvider):
|
|||||||
if is_code_context:
|
if is_code_context:
|
||||||
context.append(line)
|
context.append(line)
|
||||||
else:
|
else:
|
||||||
description.append(
|
description.append(line.replace('*', ''))
|
||||||
line.replace('*', '')
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
return ('\n'.join(description), '\n'.join(context) + '\n' if context else '')
|
||||||
'\n'.join(description),
|
|
||||||
'\n'.join(context) + '\n' if context else ''
|
|
||||||
)
|
|
||||||
|
|
||||||
def publish_code_suggestions(self, code_suggestions: list):
|
def publish_code_suggestions(self, code_suggestions: list):
|
||||||
msg = []
|
msg = []
|
||||||
@ -372,15 +372,19 @@ class GerritProvider(GitProvider):
|
|||||||
|
|
||||||
def publish_inline_comments(self, comments: list[dict]):
|
def publish_inline_comments(self, comments: list[dict]):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
'Publishing inline comments is not implemented for the gerrit '
|
'Publishing inline comments is not implemented for the gerrit ' 'provider'
|
||||||
'provider')
|
)
|
||||||
|
|
||||||
def publish_inline_comment(self, body: str, relevant_file: str,
|
def publish_inline_comment(
|
||||||
relevant_line_in_file: str, original_suggestion=None):
|
self,
|
||||||
|
body: str,
|
||||||
|
relevant_file: str,
|
||||||
|
relevant_line_in_file: str,
|
||||||
|
original_suggestion=None,
|
||||||
|
):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
'Publishing inline comments is not implemented for the gerrit '
|
'Publishing inline comments is not implemented for the gerrit ' 'provider'
|
||||||
'provider')
|
)
|
||||||
|
|
||||||
|
|
||||||
def publish_labels(self, labels):
|
def publish_labels(self, labels):
|
||||||
# Not applicable to the local git provider,
|
# Not applicable to the local git provider,
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
|
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -9,6 +10,7 @@ from utils.pr_agent.log import get_logger
|
|||||||
|
|
||||||
MAX_FILES_ALLOWED_FULL = 50
|
MAX_FILES_ALLOWED_FULL = 50
|
||||||
|
|
||||||
|
|
||||||
class GitProvider(ABC):
|
class GitProvider(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def is_supported(self, capability: str) -> bool:
|
def is_supported(self, capability: str) -> bool:
|
||||||
@ -61,11 +63,18 @@ class GitProvider(ABC):
|
|||||||
def reply_to_comment_from_comment_id(self, comment_id: int, body: str):
|
def reply_to_comment_from_comment_id(self, comment_id: int, body: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_pr_description(self, full: bool = True, split_changes_walkthrough=False) -> str or tuple:
|
def get_pr_description(
|
||||||
|
self, full: bool = True, split_changes_walkthrough=False
|
||||||
|
) -> str or tuple:
|
||||||
from utils.pr_agent.algo.utils import clip_tokens
|
from utils.pr_agent.algo.utils import clip_tokens
|
||||||
from utils.pr_agent.config_loader import get_settings
|
from utils.pr_agent.config_loader import get_settings
|
||||||
max_tokens_description = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
|
|
||||||
description = self.get_pr_description_full() if full else self.get_user_description()
|
max_tokens_description = get_settings().get(
|
||||||
|
"CONFIG.MAX_DESCRIPTION_TOKENS", None
|
||||||
|
)
|
||||||
|
description = (
|
||||||
|
self.get_pr_description_full() if full else self.get_user_description()
|
||||||
|
)
|
||||||
if split_changes_walkthrough:
|
if split_changes_walkthrough:
|
||||||
description, files = process_description(description)
|
description, files = process_description(description)
|
||||||
if max_tokens_description:
|
if max_tokens_description:
|
||||||
@ -94,7 +103,9 @@ class GitProvider(ABC):
|
|||||||
# return nothing (empty string) because it means there is no user description
|
# return nothing (empty string) because it means there is no user description
|
||||||
user_description_header = "### **user description**"
|
user_description_header = "### **user description**"
|
||||||
if user_description_header not in description_lowercase:
|
if user_description_header not in description_lowercase:
|
||||||
get_logger().info(f"Existing description was generated by the pr-agent, but it doesn't contain a user description")
|
get_logger().info(
|
||||||
|
f"Existing description was generated by the pr-agent, but it doesn't contain a user description"
|
||||||
|
)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# otherwise, extract the original user description from the existing pr-agent description and return it
|
# otherwise, extract the original user description from the existing pr-agent description and return it
|
||||||
@ -103,9 +114,11 @@ class GitProvider(ABC):
|
|||||||
|
|
||||||
# the 'user description' is in the beginning. extract and return it
|
# the 'user description' is in the beginning. extract and return it
|
||||||
possible_headers = self._possible_headers()
|
possible_headers = self._possible_headers()
|
||||||
start_position = description_lowercase.find(user_description_header) + len(user_description_header)
|
start_position = description_lowercase.find(user_description_header) + len(
|
||||||
|
user_description_header
|
||||||
|
)
|
||||||
end_position = len(description)
|
end_position = len(description)
|
||||||
for header in possible_headers: # try to clip at the next header
|
for header in possible_headers: # try to clip at the next header
|
||||||
if header != user_description_header and header in description_lowercase:
|
if header != user_description_header and header in description_lowercase:
|
||||||
end_position = min(end_position, description_lowercase.find(header))
|
end_position = min(end_position, description_lowercase.find(header))
|
||||||
if end_position != len(description) and end_position > start_position:
|
if end_position != len(description) and end_position > start_position:
|
||||||
@ -115,20 +128,34 @@ class GitProvider(ABC):
|
|||||||
else:
|
else:
|
||||||
original_user_description = description.split("___")[0].strip()
|
original_user_description = description.split("___")[0].strip()
|
||||||
if original_user_description.lower().startswith(user_description_header):
|
if original_user_description.lower().startswith(user_description_header):
|
||||||
original_user_description = original_user_description[len(user_description_header):].strip()
|
original_user_description = original_user_description[
|
||||||
|
len(user_description_header) :
|
||||||
|
].strip()
|
||||||
|
|
||||||
get_logger().info(f"Extracted user description from existing description",
|
get_logger().info(
|
||||||
description=original_user_description)
|
f"Extracted user description from existing description",
|
||||||
|
description=original_user_description,
|
||||||
|
)
|
||||||
self.user_description = original_user_description
|
self.user_description = original_user_description
|
||||||
return original_user_description
|
return original_user_description
|
||||||
|
|
||||||
def _possible_headers(self):
|
def _possible_headers(self):
|
||||||
return ("### **user description**", "### **pr type**", "### **pr description**", "### **pr labels**", "### **type**", "### **description**",
|
return (
|
||||||
"### **labels**", "### 🤖 generated by pr agent")
|
"### **user description**",
|
||||||
|
"### **pr type**",
|
||||||
|
"### **pr description**",
|
||||||
|
"### **pr labels**",
|
||||||
|
"### **type**",
|
||||||
|
"### **description**",
|
||||||
|
"### **labels**",
|
||||||
|
"### 🤖 generated by pr agent",
|
||||||
|
)
|
||||||
|
|
||||||
def _is_generated_by_pr_agent(self, description_lowercase: str) -> bool:
|
def _is_generated_by_pr_agent(self, description_lowercase: str) -> bool:
|
||||||
possible_headers = self._possible_headers()
|
possible_headers = self._possible_headers()
|
||||||
return any(description_lowercase.startswith(header) for header in possible_headers)
|
return any(
|
||||||
|
description_lowercase.startswith(header) for header in possible_headers
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_repo_settings(self):
|
def get_repo_settings(self):
|
||||||
@ -140,10 +167,17 @@ class GitProvider(ABC):
|
|||||||
def get_pr_id(self):
|
def get_pr_id(self):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
|
def get_line_link(
|
||||||
|
self,
|
||||||
|
relevant_file: str,
|
||||||
|
relevant_line_start: int,
|
||||||
|
relevant_line_end: int = None,
|
||||||
|
) -> str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def get_lines_link_original_file(self, filepath:str, component_range: Range) -> str:
|
def get_lines_link_original_file(
|
||||||
|
self, filepath: str, component_range: Range
|
||||||
|
) -> str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
#### comments operations ####
|
#### comments operations ####
|
||||||
@ -151,18 +185,24 @@ class GitProvider(ABC):
|
|||||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def publish_persistent_comment(self, pr_comment: str,
|
def publish_persistent_comment(
|
||||||
initial_header: str,
|
self,
|
||||||
update_header: bool = True,
|
pr_comment: str,
|
||||||
name='review',
|
initial_header: str,
|
||||||
final_update_message=True):
|
update_header: bool = True,
|
||||||
|
name='review',
|
||||||
|
final_update_message=True,
|
||||||
|
):
|
||||||
self.publish_comment(pr_comment)
|
self.publish_comment(pr_comment)
|
||||||
|
|
||||||
def publish_persistent_comment_full(self, pr_comment: str,
|
def publish_persistent_comment_full(
|
||||||
initial_header: str,
|
self,
|
||||||
update_header: bool = True,
|
pr_comment: str,
|
||||||
name='review',
|
initial_header: str,
|
||||||
final_update_message=True):
|
update_header: bool = True,
|
||||||
|
name='review',
|
||||||
|
final_update_message=True,
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
prev_comments = list(self.get_issue_comments())
|
prev_comments = list(self.get_issue_comments())
|
||||||
for comment in prev_comments:
|
for comment in prev_comments:
|
||||||
@ -171,29 +211,46 @@ class GitProvider(ABC):
|
|||||||
comment_url = self.get_comment_url(comment)
|
comment_url = self.get_comment_url(comment)
|
||||||
if update_header:
|
if update_header:
|
||||||
updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n"
|
updated_header = f"{initial_header}\n\n#### ({name.capitalize()} updated until commit {latest_commit_url})\n"
|
||||||
pr_comment_updated = pr_comment.replace(initial_header, updated_header)
|
pr_comment_updated = pr_comment.replace(
|
||||||
|
initial_header, updated_header
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
pr_comment_updated = pr_comment
|
pr_comment_updated = pr_comment
|
||||||
get_logger().info(f"Persistent mode - updating comment {comment_url} to latest {name} message")
|
get_logger().info(
|
||||||
|
f"Persistent mode - updating comment {comment_url} to latest {name} message"
|
||||||
|
)
|
||||||
# response = self.mr.notes.update(comment.id, {'body': pr_comment_updated})
|
# response = self.mr.notes.update(comment.id, {'body': pr_comment_updated})
|
||||||
self.edit_comment(comment, pr_comment_updated)
|
self.edit_comment(comment, pr_comment_updated)
|
||||||
if final_update_message:
|
if final_update_message:
|
||||||
self.publish_comment(
|
self.publish_comment(
|
||||||
f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}")
|
f"**[Persistent {name}]({comment_url})** updated to latest commit {latest_commit_url}"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().exception(f"Failed to update persistent review, error: {e}")
|
get_logger().exception(f"Failed to update persistent review, error: {e}")
|
||||||
pass
|
pass
|
||||||
self.publish_comment(pr_comment)
|
self.publish_comment(pr_comment)
|
||||||
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
|
def publish_inline_comment(
|
||||||
|
self,
|
||||||
|
body: str,
|
||||||
|
relevant_file: str,
|
||||||
|
relevant_line_in_file: str,
|
||||||
|
original_suggestion=None,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
|
def create_inline_comment(
|
||||||
absolute_position: int = None):
|
self,
|
||||||
raise NotImplementedError("This git provider does not support creating inline comments yet")
|
body: str,
|
||||||
|
relevant_file: str,
|
||||||
|
relevant_line_in_file: str,
|
||||||
|
absolute_position: int = None,
|
||||||
|
):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"This git provider does not support creating inline comments yet"
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def publish_inline_comments(self, comments: list[dict]):
|
def publish_inline_comments(self, comments: list[dict]):
|
||||||
@ -227,7 +284,9 @@ class GitProvider(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
|
def add_eyes_reaction(
|
||||||
|
self, issue_comment_id: int, disable_eyes: bool = False
|
||||||
|
) -> Optional[int]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -284,16 +343,23 @@ def get_main_pr_language(languages, files) -> str:
|
|||||||
if not file:
|
if not file:
|
||||||
continue
|
continue
|
||||||
if isinstance(file, str):
|
if isinstance(file, str):
|
||||||
file = FilePatchInfo(base_file=None, head_file=None, patch=None, filename=file)
|
file = FilePatchInfo(
|
||||||
|
base_file=None, head_file=None, patch=None, filename=file
|
||||||
|
)
|
||||||
extension_list.append(file.filename.rsplit('.')[-1])
|
extension_list.append(file.filename.rsplit('.')[-1])
|
||||||
|
|
||||||
# get the most common extension
|
# get the most common extension
|
||||||
most_common_extension = '.' + max(set(extension_list), key=extension_list.count)
|
most_common_extension = '.' + max(set(extension_list), key=extension_list.count)
|
||||||
try:
|
try:
|
||||||
language_extension_map_org = get_settings().language_extension_map_org
|
language_extension_map_org = get_settings().language_extension_map_org
|
||||||
language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
|
language_extension_map = {
|
||||||
|
k.lower(): v for k, v in language_extension_map_org.items()
|
||||||
|
}
|
||||||
|
|
||||||
if top_language in language_extension_map and most_common_extension in language_extension_map[top_language]:
|
if (
|
||||||
|
top_language in language_extension_map
|
||||||
|
and most_common_extension in language_extension_map[top_language]
|
||||||
|
):
|
||||||
main_language_str = top_language
|
main_language_str = top_language
|
||||||
else:
|
else:
|
||||||
for language, extensions in language_extension_map.items():
|
for language, extensions in language_extension_map.items():
|
||||||
@ -332,8 +398,6 @@ def get_main_pr_language(languages, files) -> str:
|
|||||||
return main_language_str
|
return main_language_str
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class IncrementalPR:
|
class IncrementalPR:
|
||||||
def __init__(self, is_incremental: bool = False):
|
def __init__(self, is_incremental: bool = False):
|
||||||
self.is_incremental = is_incremental
|
self.is_incremental = is_incremental
|
||||||
|
|||||||
@ -18,14 +18,23 @@ from ..algo.file_filter import filter_ignored
|
|||||||
from ..algo.git_patch_processing import extract_hunk_headers
|
from ..algo.git_patch_processing import extract_hunk_headers
|
||||||
from ..algo.language_handler import is_valid_file
|
from ..algo.language_handler import is_valid_file
|
||||||
from ..algo.types import EDIT_TYPE
|
from ..algo.types import EDIT_TYPE
|
||||||
from ..algo.utils import (PRReviewHeader, Range, clip_tokens,
|
from ..algo.utils import (
|
||||||
find_line_number_of_relevant_line_in_file,
|
PRReviewHeader,
|
||||||
load_large_diff, set_file_languages)
|
Range,
|
||||||
|
clip_tokens,
|
||||||
|
find_line_number_of_relevant_line_in_file,
|
||||||
|
load_large_diff,
|
||||||
|
set_file_languages,
|
||||||
|
)
|
||||||
from ..config_loader import get_settings
|
from ..config_loader import get_settings
|
||||||
from ..log import get_logger
|
from ..log import get_logger
|
||||||
from ..servers.utils import RateLimitExceeded
|
from ..servers.utils import RateLimitExceeded
|
||||||
from .git_provider import (MAX_FILES_ALLOWED_FULL, FilePatchInfo, GitProvider,
|
from .git_provider import (
|
||||||
IncrementalPR)
|
MAX_FILES_ALLOWED_FULL,
|
||||||
|
FilePatchInfo,
|
||||||
|
GitProvider,
|
||||||
|
IncrementalPR,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GithubProvider(GitProvider):
|
class GithubProvider(GitProvider):
|
||||||
@ -36,8 +45,14 @@ class GithubProvider(GitProvider):
|
|||||||
except Exception:
|
except Exception:
|
||||||
self.installation_id = None
|
self.installation_id = None
|
||||||
self.max_comment_chars = 65000
|
self.max_comment_chars = 65000
|
||||||
self.base_url = get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/") # "https://api.github.com"
|
self.base_url = (
|
||||||
self.base_url_html = self.base_url.split("api/")[0].rstrip("/") if "api/" in self.base_url else "https://github.com"
|
get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/")
|
||||||
|
) # "https://api.github.com"
|
||||||
|
self.base_url_html = (
|
||||||
|
self.base_url.split("api/")[0].rstrip("/")
|
||||||
|
if "api/" in self.base_url
|
||||||
|
else "https://github.com"
|
||||||
|
)
|
||||||
self.github_client = self._get_github_client()
|
self.github_client = self._get_github_client()
|
||||||
self.repo = None
|
self.repo = None
|
||||||
self.pr_num = None
|
self.pr_num = None
|
||||||
@ -50,7 +65,9 @@ class GithubProvider(GitProvider):
|
|||||||
self.set_pr(pr_url)
|
self.set_pr(pr_url)
|
||||||
self.pr_commits = list(self.pr.get_commits())
|
self.pr_commits = list(self.pr.get_commits())
|
||||||
self.last_commit_id = self.pr_commits[-1]
|
self.last_commit_id = self.pr_commits[-1]
|
||||||
self.pr_url = self.get_pr_url() # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object
|
self.pr_url = (
|
||||||
|
self.get_pr_url()
|
||||||
|
) # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object
|
||||||
else:
|
else:
|
||||||
self.pr_commits = None
|
self.pr_commits = None
|
||||||
|
|
||||||
@ -80,10 +97,14 @@ class GithubProvider(GitProvider):
|
|||||||
# Get all files changed during the commit range
|
# Get all files changed during the commit range
|
||||||
|
|
||||||
for commit in self.incremental.commits_range:
|
for commit in self.incremental.commits_range:
|
||||||
if commit.commit.message.startswith(f"Merge branch '{self._get_repo().default_branch}'"):
|
if commit.commit.message.startswith(
|
||||||
|
f"Merge branch '{self._get_repo().default_branch}'"
|
||||||
|
):
|
||||||
get_logger().info(f"Skipping merge commit {commit.commit.message}")
|
get_logger().info(f"Skipping merge commit {commit.commit.message}")
|
||||||
continue
|
continue
|
||||||
self.unreviewed_files_set.update({file.filename: file for file in commit.files})
|
self.unreviewed_files_set.update(
|
||||||
|
{file.filename: file for file in commit.files}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
get_logger().info("No previous review found, will review the entire PR")
|
get_logger().info("No previous review found, will review the entire PR")
|
||||||
self.incremental.is_incremental = False
|
self.incremental.is_incremental = False
|
||||||
@ -98,7 +119,11 @@ class GithubProvider(GitProvider):
|
|||||||
else:
|
else:
|
||||||
self.incremental.last_seen_commit = self.pr_commits[index]
|
self.incremental.last_seen_commit = self.pr_commits[index]
|
||||||
break
|
break
|
||||||
return self.pr_commits[first_new_commit_index:] if first_new_commit_index is not None else []
|
return (
|
||||||
|
self.pr_commits[first_new_commit_index:]
|
||||||
|
if first_new_commit_index is not None
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
def get_previous_review(self, *, full: bool, incremental: bool):
|
def get_previous_review(self, *, full: bool, incremental: bool):
|
||||||
if not (full or incremental):
|
if not (full or incremental):
|
||||||
@ -121,7 +146,7 @@ class GithubProvider(GitProvider):
|
|||||||
git_files = context.get("git_files", None)
|
git_files = context.get("git_files", None)
|
||||||
if git_files:
|
if git_files:
|
||||||
return git_files
|
return git_files
|
||||||
self.git_files = list(self.pr.get_files()) # 'list' to handle pagination
|
self.git_files = list(self.pr.get_files()) # 'list' to handle pagination
|
||||||
context["git_files"] = self.git_files
|
context["git_files"] = self.git_files
|
||||||
return self.git_files
|
return self.git_files
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -138,8 +163,13 @@ class GithubProvider(GitProvider):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
@retry(exceptions=RateLimitExceeded,
|
@retry(
|
||||||
tries=get_settings().github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3))
|
exceptions=RateLimitExceeded,
|
||||||
|
tries=get_settings().github.ratelimit_retries,
|
||||||
|
delay=2,
|
||||||
|
backoff=2,
|
||||||
|
jitter=(1, 3),
|
||||||
|
)
|
||||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||||
"""
|
"""
|
||||||
Retrieves the list of files that have been modified, added, deleted, or renamed in a pull request in GitHub,
|
Retrieves the list of files that have been modified, added, deleted, or renamed in a pull request in GitHub,
|
||||||
@ -167,9 +197,10 @@ class GithubProvider(GitProvider):
|
|||||||
try:
|
try:
|
||||||
names_original = [file.filename for file in files_original]
|
names_original = [file.filename for file in files_original]
|
||||||
names_new = [file.filename for file in files]
|
names_new = [file.filename for file in files]
|
||||||
get_logger().info(f"Filtered out [ignore] files for pull request:", extra=
|
get_logger().info(
|
||||||
{"files": names_original,
|
f"Filtered out [ignore] files for pull request:",
|
||||||
"filtered_files": names_new})
|
extra={"files": names_original, "filtered_files": names_new},
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -184,14 +215,17 @@ class GithubProvider(GitProvider):
|
|||||||
repo = self.repo_obj
|
repo = self.repo_obj
|
||||||
pr = self.pr
|
pr = self.pr
|
||||||
try:
|
try:
|
||||||
compare = repo.compare(pr.base.sha, pr.head.sha) # communication with GitHub
|
compare = repo.compare(
|
||||||
|
pr.base.sha, pr.head.sha
|
||||||
|
) # communication with GitHub
|
||||||
merge_base_commit = compare.merge_base_commit
|
merge_base_commit = compare.merge_base_commit
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Failed to get merge base commit: {e}")
|
get_logger().error(f"Failed to get merge base commit: {e}")
|
||||||
merge_base_commit = pr.base
|
merge_base_commit = pr.base
|
||||||
if merge_base_commit.sha != pr.base.sha:
|
if merge_base_commit.sha != pr.base.sha:
|
||||||
get_logger().info(
|
get_logger().info(
|
||||||
f"Using merge base commit {merge_base_commit.sha} instead of base commit ")
|
f"Using merge base commit {merge_base_commit.sha} instead of base commit "
|
||||||
|
)
|
||||||
|
|
||||||
counter_valid = 0
|
counter_valid = 0
|
||||||
for file in files:
|
for file in files:
|
||||||
@ -207,29 +241,48 @@ class GithubProvider(GitProvider):
|
|||||||
# allow only a limited number of files to be fully loaded. We can manage the rest with diffs only
|
# allow only a limited number of files to be fully loaded. We can manage the rest with diffs only
|
||||||
counter_valid += 1
|
counter_valid += 1
|
||||||
avoid_load = False
|
avoid_load = False
|
||||||
if counter_valid >= MAX_FILES_ALLOWED_FULL and patch and not self.incremental.is_incremental:
|
if (
|
||||||
|
counter_valid >= MAX_FILES_ALLOWED_FULL
|
||||||
|
and patch
|
||||||
|
and not self.incremental.is_incremental
|
||||||
|
):
|
||||||
avoid_load = True
|
avoid_load = True
|
||||||
if counter_valid == MAX_FILES_ALLOWED_FULL:
|
if counter_valid == MAX_FILES_ALLOWED_FULL:
|
||||||
get_logger().info(f"Too many files in PR, will avoid loading full content for rest of files")
|
get_logger().info(
|
||||||
|
f"Too many files in PR, will avoid loading full content for rest of files"
|
||||||
|
)
|
||||||
|
|
||||||
if avoid_load:
|
if avoid_load:
|
||||||
new_file_content_str = ""
|
new_file_content_str = ""
|
||||||
else:
|
else:
|
||||||
new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) # communication with GitHub
|
new_file_content_str = self._get_pr_file_content(
|
||||||
|
file, self.pr.head.sha
|
||||||
|
) # communication with GitHub
|
||||||
|
|
||||||
if self.incremental.is_incremental and self.unreviewed_files_set:
|
if self.incremental.is_incremental and self.unreviewed_files_set:
|
||||||
original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha)
|
original_file_content_str = self._get_pr_file_content(
|
||||||
patch = load_large_diff(file.filename, new_file_content_str, original_file_content_str)
|
file, self.incremental.last_seen_commit_sha
|
||||||
|
)
|
||||||
|
patch = load_large_diff(
|
||||||
|
file.filename,
|
||||||
|
new_file_content_str,
|
||||||
|
original_file_content_str,
|
||||||
|
)
|
||||||
self.unreviewed_files_set[file.filename] = patch
|
self.unreviewed_files_set[file.filename] = patch
|
||||||
else:
|
else:
|
||||||
if avoid_load:
|
if avoid_load:
|
||||||
original_file_content_str = ""
|
original_file_content_str = ""
|
||||||
else:
|
else:
|
||||||
original_file_content_str = self._get_pr_file_content(file, merge_base_commit.sha)
|
original_file_content_str = self._get_pr_file_content(
|
||||||
|
file, merge_base_commit.sha
|
||||||
|
)
|
||||||
# original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
|
# original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
|
||||||
if not patch:
|
if not patch:
|
||||||
patch = load_large_diff(file.filename, new_file_content_str, original_file_content_str)
|
patch = load_large_diff(
|
||||||
|
file.filename,
|
||||||
|
new_file_content_str,
|
||||||
|
original_file_content_str,
|
||||||
|
)
|
||||||
|
|
||||||
if file.status == 'added':
|
if file.status == 'added':
|
||||||
edit_type = EDIT_TYPE.ADDED
|
edit_type = EDIT_TYPE.ADDED
|
||||||
@ -249,16 +302,27 @@ class GithubProvider(GitProvider):
|
|||||||
num_minus_lines = file.deletions
|
num_minus_lines = file.deletions
|
||||||
else:
|
else:
|
||||||
patch_lines = patch.splitlines(keepends=True)
|
patch_lines = patch.splitlines(keepends=True)
|
||||||
num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
|
num_plus_lines = len(
|
||||||
num_minus_lines = len([line for line in patch_lines if line.startswith('-')])
|
[line for line in patch_lines if line.startswith('+')]
|
||||||
|
)
|
||||||
|
num_minus_lines = len(
|
||||||
|
[line for line in patch_lines if line.startswith('-')]
|
||||||
|
)
|
||||||
|
|
||||||
file_patch_canonical_structure = FilePatchInfo(original_file_content_str, new_file_content_str, patch,
|
file_patch_canonical_structure = FilePatchInfo(
|
||||||
file.filename, edit_type=edit_type,
|
original_file_content_str,
|
||||||
num_plus_lines=num_plus_lines,
|
new_file_content_str,
|
||||||
num_minus_lines=num_minus_lines,)
|
patch,
|
||||||
|
file.filename,
|
||||||
|
edit_type=edit_type,
|
||||||
|
num_plus_lines=num_plus_lines,
|
||||||
|
num_minus_lines=num_minus_lines,
|
||||||
|
)
|
||||||
diff_files.append(file_patch_canonical_structure)
|
diff_files.append(file_patch_canonical_structure)
|
||||||
if invalid_files_names:
|
if invalid_files_names:
|
||||||
get_logger().info(f"Filtered out files with invalid extensions: {invalid_files_names}")
|
get_logger().info(
|
||||||
|
f"Filtered out files with invalid extensions: {invalid_files_names}"
|
||||||
|
)
|
||||||
|
|
||||||
self.diff_files = diff_files
|
self.diff_files = diff_files
|
||||||
try:
|
try:
|
||||||
@ -269,8 +333,10 @@ class GithubProvider(GitProvider):
|
|||||||
return diff_files
|
return diff_files
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Failing to get diff files: {e}",
|
get_logger().error(
|
||||||
artifact={"traceback": traceback.format_exc()})
|
f"Failing to get diff files: {e}",
|
||||||
|
artifact={"traceback": traceback.format_exc()},
|
||||||
|
)
|
||||||
raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e
|
raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e
|
||||||
|
|
||||||
def publish_description(self, pr_title: str, pr_body: str):
|
def publish_description(self, pr_title: str, pr_body: str):
|
||||||
@ -282,16 +348,23 @@ class GithubProvider(GitProvider):
|
|||||||
def get_comment_url(self, comment) -> str:
|
def get_comment_url(self, comment) -> str:
|
||||||
return comment.html_url
|
return comment.html_url
|
||||||
|
|
||||||
def publish_persistent_comment(self, pr_comment: str,
|
def publish_persistent_comment(
|
||||||
initial_header: str,
|
self,
|
||||||
update_header: bool = True,
|
pr_comment: str,
|
||||||
name='review',
|
initial_header: str,
|
||||||
final_update_message=True):
|
update_header: bool = True,
|
||||||
self.publish_persistent_comment_full(pr_comment, initial_header, update_header, name, final_update_message)
|
name='review',
|
||||||
|
final_update_message=True,
|
||||||
|
):
|
||||||
|
self.publish_persistent_comment_full(
|
||||||
|
pr_comment, initial_header, update_header, name, final_update_message
|
||||||
|
)
|
||||||
|
|
||||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||||
if is_temporary and not get_settings().config.publish_output_progress:
|
if is_temporary and not get_settings().config.publish_output_progress:
|
||||||
get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
|
get_logger().debug(
|
||||||
|
f"Skipping publish_comment for temporary comment: {pr_comment}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
pr_comment = self.limit_output_characters(pr_comment, self.max_comment_chars)
|
pr_comment = self.limit_output_characters(pr_comment, self.max_comment_chars)
|
||||||
response = self.pr.create_issue_comment(pr_comment)
|
response = self.pr.create_issue_comment(pr_comment)
|
||||||
@ -303,42 +376,68 @@ class GithubProvider(GitProvider):
|
|||||||
self.pr.comments_list.append(response)
|
self.pr.comments_list.append(response)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
|
def publish_inline_comment(
|
||||||
|
self,
|
||||||
|
body: str,
|
||||||
|
relevant_file: str,
|
||||||
|
relevant_line_in_file: str,
|
||||||
|
original_suggestion=None,
|
||||||
|
):
|
||||||
body = self.limit_output_characters(body, self.max_comment_chars)
|
body = self.limit_output_characters(body, self.max_comment_chars)
|
||||||
self.publish_inline_comments([self.create_inline_comment(body, relevant_file, relevant_line_in_file)])
|
self.publish_inline_comments(
|
||||||
|
[self.create_inline_comment(body, relevant_file, relevant_line_in_file)]
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_inline_comment(
|
||||||
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str,
|
self,
|
||||||
absolute_position: int = None):
|
body: str,
|
||||||
|
relevant_file: str,
|
||||||
|
relevant_line_in_file: str,
|
||||||
|
absolute_position: int = None,
|
||||||
|
):
|
||||||
body = self.limit_output_characters(body, self.max_comment_chars)
|
body = self.limit_output_characters(body, self.max_comment_chars)
|
||||||
position, absolute_position = find_line_number_of_relevant_line_in_file(self.diff_files,
|
position, absolute_position = find_line_number_of_relevant_line_in_file(
|
||||||
relevant_file.strip('`'),
|
self.diff_files,
|
||||||
relevant_line_in_file,
|
relevant_file.strip('`'),
|
||||||
absolute_position)
|
relevant_line_in_file,
|
||||||
|
absolute_position,
|
||||||
|
)
|
||||||
if position == -1:
|
if position == -1:
|
||||||
get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
|
get_logger().info(
|
||||||
|
f"Could not find position for {relevant_file} {relevant_line_in_file}"
|
||||||
|
)
|
||||||
subject_type = "FILE"
|
subject_type = "FILE"
|
||||||
else:
|
else:
|
||||||
subject_type = "LINE"
|
subject_type = "LINE"
|
||||||
path = relevant_file.strip()
|
path = relevant_file.strip()
|
||||||
return dict(body=body, path=path, position=position) if subject_type == "LINE" else {}
|
return (
|
||||||
|
dict(body=body, path=path, position=position)
|
||||||
|
if subject_type == "LINE"
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
def publish_inline_comments(self, comments: list[dict], disable_fallback: bool = False):
|
def publish_inline_comments(
|
||||||
|
self, comments: list[dict], disable_fallback: bool = False
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
# publish all comments in a single message
|
# publish all comments in a single message
|
||||||
self.pr.create_review(commit=self.last_commit_id, comments=comments)
|
self.pr.create_review(commit=self.last_commit_id, comments=comments)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().info(f"Initially failed to publish inline comments as committable")
|
get_logger().info(
|
||||||
|
f"Initially failed to publish inline comments as committable"
|
||||||
|
)
|
||||||
|
|
||||||
if (getattr(e, "status", None) == 422 and not disable_fallback):
|
if getattr(e, "status", None) == 422 and not disable_fallback:
|
||||||
pass # continue to try _publish_inline_comments_fallback_with_verification
|
pass # continue to try _publish_inline_comments_fallback_with_verification
|
||||||
else:
|
else:
|
||||||
raise e # will end up with publishing the comments one by one
|
raise e # will end up with publishing the comments one by one
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._publish_inline_comments_fallback_with_verification(comments)
|
self._publish_inline_comments_fallback_with_verification(comments)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Failed to publish inline code comments fallback, error: {e}")
|
get_logger().error(
|
||||||
|
f"Failed to publish inline code comments fallback, error: {e}"
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def _publish_inline_comments_fallback_with_verification(self, comments: list[dict]):
|
def _publish_inline_comments_fallback_with_verification(self, comments: list[dict]):
|
||||||
@ -352,20 +451,27 @@ class GithubProvider(GitProvider):
|
|||||||
# publish as a group the verified comments
|
# publish as a group the verified comments
|
||||||
if verified_comments:
|
if verified_comments:
|
||||||
try:
|
try:
|
||||||
self.pr.create_review(commit=self.last_commit_id, comments=verified_comments)
|
self.pr.create_review(
|
||||||
|
commit=self.last_commit_id, comments=verified_comments
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# try to publish one by one the invalid comments as a one-line code comment
|
# try to publish one by one the invalid comments as a one-line code comment
|
||||||
if invalid_comments and get_settings().github.try_fix_invalid_inline_comments:
|
if invalid_comments and get_settings().github.try_fix_invalid_inline_comments:
|
||||||
fixed_comments_as_one_liner = self._try_fix_invalid_inline_comments(
|
fixed_comments_as_one_liner = self._try_fix_invalid_inline_comments(
|
||||||
[comment for comment, _ in invalid_comments])
|
[comment for comment, _ in invalid_comments]
|
||||||
|
)
|
||||||
for comment in fixed_comments_as_one_liner:
|
for comment in fixed_comments_as_one_liner:
|
||||||
try:
|
try:
|
||||||
self.publish_inline_comments([comment], disable_fallback=True)
|
self.publish_inline_comments([comment], disable_fallback=True)
|
||||||
get_logger().info(f"Published invalid comment as a single line comment: {comment}")
|
get_logger().info(
|
||||||
|
f"Published invalid comment as a single line comment: {comment}"
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
get_logger().error(f"Failed to publish invalid comment as a single line comment: {comment}")
|
get_logger().error(
|
||||||
|
f"Failed to publish invalid comment as a single line comment: {comment}"
|
||||||
|
)
|
||||||
|
|
||||||
def _verify_code_comment(self, comment: dict):
|
def _verify_code_comment(self, comment: dict):
|
||||||
is_verified = False
|
is_verified = False
|
||||||
@ -374,7 +480,8 @@ class GithubProvider(GitProvider):
|
|||||||
# event ="" # By leaving this blank, you set the review action state to PENDING
|
# event ="" # By leaving this blank, you set the review action state to PENDING
|
||||||
input = dict(commit_id=self.last_commit_id.sha, comments=[comment])
|
input = dict(commit_id=self.last_commit_id.sha, comments=[comment])
|
||||||
headers, data = self.pr._requester.requestJsonAndCheck(
|
headers, data = self.pr._requester.requestJsonAndCheck(
|
||||||
"POST", f"{self.pr.url}/reviews", input=input)
|
"POST", f"{self.pr.url}/reviews", input=input
|
||||||
|
)
|
||||||
pending_review_id = data["id"]
|
pending_review_id = data["id"]
|
||||||
is_verified = True
|
is_verified = True
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
@ -383,12 +490,16 @@ class GithubProvider(GitProvider):
|
|||||||
e = err
|
e = err
|
||||||
if pending_review_id is not None:
|
if pending_review_id is not None:
|
||||||
try:
|
try:
|
||||||
self.pr._requester.requestJsonAndCheck("DELETE", f"{self.pr.url}/reviews/{pending_review_id}")
|
self.pr._requester.requestJsonAndCheck(
|
||||||
|
"DELETE", f"{self.pr.url}/reviews/{pending_review_id}"
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return is_verified, e
|
return is_verified, e
|
||||||
|
|
||||||
def _verify_code_comments(self, comments: list[dict]) -> tuple[list[dict], list[tuple[dict, Exception]]]:
|
def _verify_code_comments(
|
||||||
|
self, comments: list[dict]
|
||||||
|
) -> tuple[list[dict], list[tuple[dict, Exception]]]:
|
||||||
"""Very each comment against the GitHub API and return 2 lists: 1 of verified and 1 of invalid comments"""
|
"""Very each comment against the GitHub API and return 2 lists: 1 of verified and 1 of invalid comments"""
|
||||||
verified_comments = []
|
verified_comments = []
|
||||||
invalid_comments = []
|
invalid_comments = []
|
||||||
@ -401,17 +512,22 @@ class GithubProvider(GitProvider):
|
|||||||
invalid_comments.append((comment, e))
|
invalid_comments.append((comment, e))
|
||||||
return verified_comments, invalid_comments
|
return verified_comments, invalid_comments
|
||||||
|
|
||||||
def _try_fix_invalid_inline_comments(self, invalid_comments: list[dict]) -> list[dict]:
|
def _try_fix_invalid_inline_comments(
|
||||||
|
self, invalid_comments: list[dict]
|
||||||
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Try fixing invalid comments by removing the suggestion part and setting the comment just on the first line.
|
Try fixing invalid comments by removing the suggestion part and setting the comment just on the first line.
|
||||||
Return only comments that have been modified in some way.
|
Return only comments that have been modified in some way.
|
||||||
This is a best-effort attempt to fix invalid comments, and should be verified accordingly.
|
This is a best-effort attempt to fix invalid comments, and should be verified accordingly.
|
||||||
"""
|
"""
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
fixed_comments = []
|
fixed_comments = []
|
||||||
for comment in invalid_comments:
|
for comment in invalid_comments:
|
||||||
try:
|
try:
|
||||||
fixed_comment = copy.deepcopy(comment) # avoid modifying the original comment dict for later logging
|
fixed_comment = copy.deepcopy(
|
||||||
|
comment
|
||||||
|
) # avoid modifying the original comment dict for later logging
|
||||||
if "```suggestion" in comment["body"]:
|
if "```suggestion" in comment["body"]:
|
||||||
fixed_comment["body"] = comment["body"].split("```suggestion")[0]
|
fixed_comment["body"] = comment["body"].split("```suggestion")[0]
|
||||||
if "start_line" in comment:
|
if "start_line" in comment:
|
||||||
@ -432,7 +548,9 @@ class GithubProvider(GitProvider):
|
|||||||
"""
|
"""
|
||||||
post_parameters_list = []
|
post_parameters_list = []
|
||||||
|
|
||||||
code_suggestions_validated = self.validate_comments_inside_hunks(code_suggestions)
|
code_suggestions_validated = self.validate_comments_inside_hunks(
|
||||||
|
code_suggestions
|
||||||
|
)
|
||||||
|
|
||||||
for suggestion in code_suggestions_validated:
|
for suggestion in code_suggestions_validated:
|
||||||
body = suggestion['body']
|
body = suggestion['body']
|
||||||
@ -442,13 +560,16 @@ class GithubProvider(GitProvider):
|
|||||||
|
|
||||||
if not relevant_lines_start or relevant_lines_start == -1:
|
if not relevant_lines_start or relevant_lines_start == -1:
|
||||||
get_logger().exception(
|
get_logger().exception(
|
||||||
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
|
f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if relevant_lines_end < relevant_lines_start:
|
if relevant_lines_end < relevant_lines_start:
|
||||||
get_logger().exception(f"Failed to publish code suggestion, "
|
get_logger().exception(
|
||||||
f"relevant_lines_end is {relevant_lines_end} and "
|
f"Failed to publish code suggestion, "
|
||||||
f"relevant_lines_start is {relevant_lines_start}")
|
f"relevant_lines_end is {relevant_lines_end} and "
|
||||||
|
f"relevant_lines_start is {relevant_lines_start}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if relevant_lines_end > relevant_lines_start:
|
if relevant_lines_end > relevant_lines_start:
|
||||||
@ -484,17 +605,21 @@ class GithubProvider(GitProvider):
|
|||||||
# Log as warning for permission-related issues (usually due to polling)
|
# Log as warning for permission-related issues (usually due to polling)
|
||||||
get_logger().warning(
|
get_logger().warning(
|
||||||
"Failed to edit github comment due to permission restrictions",
|
"Failed to edit github comment due to permission restrictions",
|
||||||
artifact={"error": e})
|
artifact={"error": e},
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
get_logger().exception(f"Failed to edit github comment", artifact={"error": e})
|
get_logger().exception(
|
||||||
|
f"Failed to edit github comment", artifact={"error": e}
|
||||||
|
)
|
||||||
|
|
||||||
def edit_comment_from_comment_id(self, comment_id: int, body: str):
|
def edit_comment_from_comment_id(self, comment_id: int, body: str):
|
||||||
try:
|
try:
|
||||||
# self.pr.get_issue_comment(comment_id).edit(body)
|
# self.pr.get_issue_comment(comment_id).edit(body)
|
||||||
body = self.limit_output_characters(body, self.max_comment_chars)
|
body = self.limit_output_characters(body, self.max_comment_chars)
|
||||||
headers, data_patch = self.pr._requester.requestJsonAndCheck(
|
headers, data_patch = self.pr._requester.requestJsonAndCheck(
|
||||||
"PATCH", f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}",
|
"PATCH",
|
||||||
input={"body": body}
|
f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}",
|
||||||
|
input={"body": body},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().exception(f"Failed to edit comment, error: {e}")
|
get_logger().exception(f"Failed to edit comment, error: {e}")
|
||||||
@ -504,8 +629,9 @@ class GithubProvider(GitProvider):
|
|||||||
# self.pr.get_issue_comment(comment_id).edit(body)
|
# self.pr.get_issue_comment(comment_id).edit(body)
|
||||||
body = self.limit_output_characters(body, self.max_comment_chars)
|
body = self.limit_output_characters(body, self.max_comment_chars)
|
||||||
headers, data_patch = self.pr._requester.requestJsonAndCheck(
|
headers, data_patch = self.pr._requester.requestJsonAndCheck(
|
||||||
"POST", f"{self.base_url}/repos/{self.repo}/pulls/{self.pr_num}/comments/{comment_id}/replies",
|
"POST",
|
||||||
input={"body": body}
|
f"{self.base_url}/repos/{self.repo}/pulls/{self.pr_num}/comments/{comment_id}/replies",
|
||||||
|
input={"body": body},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().exception(f"Failed to reply comment, error: {e}")
|
get_logger().exception(f"Failed to reply comment, error: {e}")
|
||||||
@ -516,7 +642,7 @@ class GithubProvider(GitProvider):
|
|||||||
headers, data_patch = self.pr._requester.requestJsonAndCheck(
|
headers, data_patch = self.pr._requester.requestJsonAndCheck(
|
||||||
"GET", f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}"
|
"GET", f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}"
|
||||||
)
|
)
|
||||||
return data_patch.get("body","")
|
return data_patch.get("body", "")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().exception(f"Failed to edit comment, error: {e}")
|
get_logger().exception(f"Failed to edit comment, error: {e}")
|
||||||
return None
|
return None
|
||||||
@ -528,7 +654,9 @@ class GithubProvider(GitProvider):
|
|||||||
)
|
)
|
||||||
for comment in file_comments:
|
for comment in file_comments:
|
||||||
comment['commit_id'] = self.last_commit_id.sha
|
comment['commit_id'] = self.last_commit_id.sha
|
||||||
comment['body'] = self.limit_output_characters(comment['body'], self.max_comment_chars)
|
comment['body'] = self.limit_output_characters(
|
||||||
|
comment['body'], self.max_comment_chars
|
||||||
|
)
|
||||||
|
|
||||||
found = False
|
found = False
|
||||||
for existing_comment in existing_comments:
|
for existing_comment in existing_comments:
|
||||||
@ -536,13 +664,23 @@ class GithubProvider(GitProvider):
|
|||||||
our_app_name = get_settings().get("GITHUB.APP_NAME", "")
|
our_app_name = get_settings().get("GITHUB.APP_NAME", "")
|
||||||
same_comment_creator = False
|
same_comment_creator = False
|
||||||
if self.deployment_type == 'app':
|
if self.deployment_type == 'app':
|
||||||
same_comment_creator = our_app_name.lower() in existing_comment['user']['login'].lower()
|
same_comment_creator = (
|
||||||
|
our_app_name.lower()
|
||||||
|
in existing_comment['user']['login'].lower()
|
||||||
|
)
|
||||||
elif self.deployment_type == 'user':
|
elif self.deployment_type == 'user':
|
||||||
same_comment_creator = self.github_user_id == existing_comment['user']['login']
|
same_comment_creator = (
|
||||||
if existing_comment['subject_type'] == 'file' and comment['path'] == existing_comment['path'] and same_comment_creator:
|
self.github_user_id == existing_comment['user']['login']
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
existing_comment['subject_type'] == 'file'
|
||||||
|
and comment['path'] == existing_comment['path']
|
||||||
|
and same_comment_creator
|
||||||
|
):
|
||||||
headers, data_patch = self.pr._requester.requestJsonAndCheck(
|
headers, data_patch = self.pr._requester.requestJsonAndCheck(
|
||||||
"PATCH", f"{self.base_url}/repos/{self.repo}/pulls/comments/{existing_comment['id']}", input={"body":comment['body']}
|
"PATCH",
|
||||||
|
f"{self.base_url}/repos/{self.repo}/pulls/comments/{existing_comment['id']}",
|
||||||
|
input={"body": comment['body']},
|
||||||
)
|
)
|
||||||
found = True
|
found = True
|
||||||
break
|
break
|
||||||
@ -600,7 +738,9 @@ class GithubProvider(GitProvider):
|
|||||||
deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user")
|
deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user")
|
||||||
|
|
||||||
if deployment_type != 'user':
|
if deployment_type != 'user':
|
||||||
raise ValueError("Deployment mode must be set to 'user' to get notifications")
|
raise ValueError(
|
||||||
|
"Deployment mode must be set to 'user' to get notifications"
|
||||||
|
)
|
||||||
|
|
||||||
notifications = self.github_client.get_user().get_notifications(since=since)
|
notifications = self.github_client.get_user().get_notifications(since=since)
|
||||||
return notifications
|
return notifications
|
||||||
@ -621,13 +761,16 @@ class GithubProvider(GitProvider):
|
|||||||
def get_workspace_name(self):
|
def get_workspace_name(self):
|
||||||
return self.repo.split('/')[0]
|
return self.repo.split('/')[0]
|
||||||
|
|
||||||
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
|
def add_eyes_reaction(
|
||||||
|
self, issue_comment_id: int, disable_eyes: bool = False
|
||||||
|
) -> Optional[int]:
|
||||||
if disable_eyes:
|
if disable_eyes:
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
headers, data_patch = self.pr._requester.requestJsonAndCheck(
|
headers, data_patch = self.pr._requester.requestJsonAndCheck(
|
||||||
"POST", f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions",
|
"POST",
|
||||||
input={"content": "eyes"}
|
f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions",
|
||||||
|
input={"content": "eyes"},
|
||||||
)
|
)
|
||||||
return data_patch.get("id", None)
|
return data_patch.get("id", None)
|
||||||
except Exception as e:
|
except Exception as e:
|
@@ -639,7 +782,7 @@ class GithubProvider(GitProvider):
             # self.pr.get_issue_comment(issue_comment_id).delete_reaction(reaction_id)
             headers, data_patch = self.pr._requester.requestJsonAndCheck(
                 "DELETE",
-                f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions/{reaction_id}"
+                f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions/{reaction_id}",
             )
             return True
         except Exception as e:
@@ -655,7 +798,9 @@ class GithubProvider(GitProvider):
         path_parts = parsed_url.path.strip('/').split('/')
         if 'api.github.com' in parsed_url.netloc or '/api/v3' in pr_url:
             if len(path_parts) < 5 or path_parts[3] != 'pulls':
-                raise ValueError("The provided URL does not appear to be a GitHub PR URL")
+                raise ValueError(
+                    "The provided URL does not appear to be a GitHub PR URL"
+                )
             repo_name = '/'.join(path_parts[1:3])
             try:
                 pr_number = int(path_parts[4])
@@ -683,7 +828,9 @@ class GithubProvider(GitProvider):
         path_parts = parsed_url.path.strip('/').split('/')
         if 'api.github.com' in parsed_url.netloc:
             if len(path_parts) < 5 or path_parts[3] != 'issues':
-                raise ValueError("The provided URL does not appear to be a GitHub ISSUE URL")
+                raise ValueError(
+                    "The provided URL does not appear to be a GitHub ISSUE URL"
+                )
             repo_name = '/'.join(path_parts[1:3])
             try:
                 issue_number = int(path_parts[4])
@@ -710,11 +857,18 @@ class GithubProvider(GitProvider):
                 private_key = get_settings().github.private_key
                 app_id = get_settings().github.app_id
             except AttributeError as e:
-                raise ValueError("GitHub app ID and private key are required when using GitHub app deployment") from e
+                raise ValueError(
+                    "GitHub app ID and private key are required when using GitHub app deployment"
+                ) from e
             if not self.installation_id:
-                raise ValueError("GitHub app installation ID is required when using GitHub app deployment")
-            auth = AppAuthentication(app_id=app_id, private_key=private_key,
-                                     installation_id=self.installation_id)
+                raise ValueError(
+                    "GitHub app installation ID is required when using GitHub app deployment"
+                )
+            auth = AppAuthentication(
+                app_id=app_id,
+                private_key=private_key,
+                installation_id=self.installation_id,
+            )
             return Github(app_auth=auth, base_url=self.base_url)
 
         if deployment_type == 'user':
@@ -723,19 +877,21 @@ class GithubProvider(GitProvider):
             except AttributeError as e:
                 raise ValueError(
                     "GitHub token is required when using user deployment. See: "
-                    "https://github.com/Codium-ai/pr-agent#method-2-run-from-source") from e
+                    "https://github.com/Codium-ai/pr-agent#method-2-run-from-source"
+                ) from e
             return Github(auth=Auth.Token(token), base_url=self.base_url)
 
     def _get_repo(self):
-        if hasattr(self, 'repo_obj') and \
-                hasattr(self.repo_obj, 'full_name') and \
-                self.repo_obj.full_name == self.repo:
+        if (
+            hasattr(self, 'repo_obj')
+            and hasattr(self.repo_obj, 'full_name')
+            and self.repo_obj.full_name == self.repo
+        ):
             return self.repo_obj
         else:
             self.repo_obj = self.github_client.get_repo(self.repo)
             return self.repo_obj
 
 
     def _get_pr(self):
         return self._get_repo().get_pull(self.pr_num)
 
@@ -755,9 +911,9 @@ class GithubProvider(GitProvider):
     ) -> None:
         try:
             file_obj = self._get_repo().get_contents(file_path, ref=branch)
-            sha1=file_obj.sha
+            sha1 = file_obj.sha
         except Exception:
-            sha1=""
+            sha1 = ""
         self.repo_obj.update_file(
             path=file_path,
             message=message,
@@ -771,9 +927,14 @@ class GithubProvider(GitProvider):
 
     def publish_labels(self, pr_types):
         try:
-            label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5",
-                               "Enhancement": "bfd4f2", "Documentation": "d4c5f9",
-                               "Other": "d1bcf9"}
+            label_color_map = {
+                "Bug fix": "1d76db",
+                "Tests": "e99695",
+                "Bug fix with tests": "c5def5",
+                "Enhancement": "bfd4f2",
+                "Documentation": "d4c5f9",
+                "Other": "d1bcf9",
+            }
             post_parameters = []
             for p in pr_types:
                 color = label_color_map.get(p, "d1bcf9") # default to "Other" color
@@ -787,11 +948,12 @@ class GithubProvider(GitProvider):
     def get_pr_labels(self, update=False):
         try:
             if not update:
-                labels =self.pr.labels
+                labels = self.pr.labels
                 return [label.name for label in labels]
             else: # obtain the latest labels. Maybe they changed while the AI was running
                 headers, labels = self.pr._requester.requestJsonAndCheck(
-                    "GET", f"{self.pr.issue_url}/labels")
+                    "GET", f"{self.pr.issue_url}/labels"
+                )
                 return [label['name'] for label in labels]
 
         except Exception as e:
@@ -813,7 +975,9 @@ class GithubProvider(GitProvider):
         try:
            commit_list = self.pr.get_commits()
            commit_messages = [commit.commit.message for commit in commit_list]
-           commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages)])
+           commit_messages_str = "\n".join(
+               [f"{i + 1}. {message}" for i, message in enumerate(commit_messages)]
+           )
         except Exception:
            commit_messages_str = ""
         if max_tokens:
@@ -822,13 +986,16 @@ class GithubProvider(GitProvider):
 
     def generate_link_to_relevant_line_number(self, suggestion) -> str:
         try:
-            relevant_file = suggestion['relevant_file'].strip('`').strip("'").strip('\n')
+            relevant_file = (
+                suggestion['relevant_file'].strip('`').strip("'").strip('\n')
+            )
             relevant_line_str = suggestion['relevant_line'].strip('\n')
             if not relevant_line_str:
                 return ""
 
-            position, absolute_position = find_line_number_of_relevant_line_in_file \
-                (self.diff_files, relevant_file, relevant_line_str)
+            position, absolute_position = find_line_number_of_relevant_line_in_file(
+                self.diff_files, relevant_file, relevant_line_str
+            )
 
             if absolute_position != -1:
                 # # link to right file only
@@ -844,7 +1011,12 @@ class GithubProvider(GitProvider):
 
         return ""
 
-    def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
+    def get_line_link(
+        self,
+        relevant_file: str,
+        relevant_line_start: int,
+        relevant_line_end: int = None,
+    ) -> str:
         sha_file = hashlib.sha256(relevant_file.encode('utf-8')).hexdigest()
         if relevant_line_start == -1:
             link = f"{self.base_url_html}/{self.repo}/pull/{self.pr_num}/files#diff-{sha_file}"
@@ -854,7 +1026,9 @@ class GithubProvider(GitProvider):
             link = f"{self.base_url_html}/{self.repo}/pull/{self.pr_num}/files#diff-{sha_file}R{relevant_line_start}"
         return link
 
-    def get_lines_link_original_file(self, filepath: str, component_range: Range) -> str:
+    def get_lines_link_original_file(
+        self, filepath: str, component_range: Range
+    ) -> str:
         """
         Returns the link to the original file on GitHub that corresponds to the given filepath and component range.
 
@@ -876,8 +1050,10 @@ class GithubProvider(GitProvider):
         line_end = component_range.line_end + 1
         # link = (f"https://github.com/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/"
         #         f"#L{line_start}-L{line_end}")
-        link = (f"{self.base_url_html}/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/"
-                f"#L{line_start}-L{line_end}")
+        link = (
+            f"{self.base_url_html}/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/"
+            f"#L{line_start}-L{line_end}"
+        )
 
         return link
 
||||||
@ -909,8 +1085,9 @@ class GithubProvider(GitProvider):
|
|||||||
}}
|
}}
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
response_tuple = self.github_client._Github__requester.requestJson("POST", "/graphql",
|
response_tuple = self.github_client._Github__requester.requestJson(
|
||||||
input={"query": query})
|
"POST", "/graphql", input={"query": query}
|
||||||
|
)
|
||||||
|
|
||||||
# Extract the JSON response from the tuple and parses it
|
# Extract the JSON response from the tuple and parses it
|
||||||
if isinstance(response_tuple, tuple) and len(response_tuple) == 3:
|
if isinstance(response_tuple, tuple) and len(response_tuple) == 3:
|
||||||
@ -919,8 +1096,12 @@ class GithubProvider(GitProvider):
|
|||||||
get_logger().error(f"Unexpected response format: {response_tuple}")
|
get_logger().error(f"Unexpected response format: {response_tuple}")
|
||||||
return sub_issues
|
return sub_issues
|
||||||
|
|
||||||
|
issue_id = (
|
||||||
issue_id = response_json.get("data", {}).get("repository", {}).get("issue", {}).get("id")
|
response_json.get("data", {})
|
||||||
|
.get("repository", {})
|
||||||
|
.get("issue", {})
|
||||||
|
.get("id")
|
||||||
|
)
|
||||||
|
|
||||||
if not issue_id:
|
if not issue_id:
|
||||||
get_logger().warning(f"Issue ID not found for {issue_url}")
|
get_logger().warning(f"Issue ID not found for {issue_url}")
|
||||||
@ -940,22 +1121,42 @@ class GithubProvider(GitProvider):
|
|||||||
}}
|
}}
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
sub_issues_response_tuple = self.github_client._Github__requester.requestJson("POST", "/graphql", input={
|
sub_issues_response_tuple = (
|
||||||
"query": sub_issues_query})
|
self.github_client._Github__requester.requestJson(
|
||||||
|
"POST", "/graphql", input={"query": sub_issues_query}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Extract the JSON response from the tuple and parses it
|
# Extract the JSON response from the tuple and parses it
|
||||||
if isinstance(sub_issues_response_tuple, tuple) and len(sub_issues_response_tuple) == 3:
|
if (
|
||||||
|
isinstance(sub_issues_response_tuple, tuple)
|
||||||
|
and len(sub_issues_response_tuple) == 3
|
||||||
|
):
|
||||||
sub_issues_response_json = json.loads(sub_issues_response_tuple[2])
|
sub_issues_response_json = json.loads(sub_issues_response_tuple[2])
|
||||||
else:
|
else:
|
||||||
get_logger().error("Unexpected sub-issues response format", artifact={"response": sub_issues_response_tuple})
|
get_logger().error(
|
||||||
|
"Unexpected sub-issues response format",
|
||||||
|
artifact={"response": sub_issues_response_tuple},
|
||||||
|
)
|
||||||
return sub_issues
|
return sub_issues
|
||||||
|
|
||||||
if not sub_issues_response_json.get("data", {}).get("node", {}).get("subIssues"):
|
if (
|
||||||
|
not sub_issues_response_json.get("data", {})
|
||||||
|
.get("node", {})
|
||||||
|
.get("subIssues")
|
||||||
|
):
|
||||||
get_logger().error("Invalid sub-issues response structure")
|
get_logger().error("Invalid sub-issues response structure")
|
||||||
return sub_issues
|
return sub_issues
|
||||||
|
|
||||||
nodes = sub_issues_response_json.get("data", {}).get("node", {}).get("subIssues", {}).get("nodes", [])
|
nodes = (
|
||||||
get_logger().info(f"Github Sub-issues fetched: {len(nodes)}", artifact={"nodes": nodes})
|
sub_issues_response_json.get("data", {})
|
||||||
|
.get("node", {})
|
||||||
|
.get("subIssues", {})
|
||||||
|
.get("nodes", [])
|
||||||
|
)
|
||||||
|
get_logger().info(
|
||||||
|
f"Github Sub-issues fetched: {len(nodes)}", artifact={"nodes": nodes}
|
||||||
|
)
|
||||||
|
|
||||||
for sub_issue in nodes:
|
for sub_issue in nodes:
|
||||||
if "url" in sub_issue:
|
if "url" in sub_issue:
|
||||||
@ -977,7 +1178,7 @@ class GithubProvider(GitProvider):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def calc_pr_statistics(self, pull_request_data: dict):
|
def calc_pr_statistics(self, pull_request_data: dict):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def validate_comments_inside_hunks(self, code_suggestions):
|
def validate_comments_inside_hunks(self, code_suggestions):
|
||||||
"""
|
"""
|
||||||
@ -986,7 +1187,8 @@ class GithubProvider(GitProvider):
|
|||||||
code_suggestions_copy = copy.deepcopy(code_suggestions)
|
code_suggestions_copy = copy.deepcopy(code_suggestions)
|
||||||
diff_files = self.get_diff_files()
|
diff_files = self.get_diff_files()
|
||||||
RE_HUNK_HEADER = re.compile(
|
RE_HUNK_HEADER = re.compile(
|
||||||
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
|
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)"
|
||||||
|
)
|
||||||
|
|
||||||
diff_files = set_file_languages(diff_files)
|
diff_files = set_file_languages(diff_files)
|
||||||
|
|
||||||
@ -995,7 +1197,6 @@ class GithubProvider(GitProvider):
|
|||||||
relevant_file_path = suggestion['relevant_file']
|
relevant_file_path = suggestion['relevant_file']
|
||||||
for file in diff_files:
|
for file in diff_files:
|
||||||
if file.filename == relevant_file_path:
|
if file.filename == relevant_file_path:
|
||||||
|
|
||||||
# generate on-demand the patches range for the relevant file
|
# generate on-demand the patches range for the relevant file
|
||||||
patch_str = file.patch
|
patch_str = file.patch
|
||||||
if not hasattr(file, 'patches_range'):
|
if not hasattr(file, 'patches_range'):
|
||||||
@ -1006,14 +1207,30 @@ class GithubProvider(GitProvider):
|
|||||||
match = RE_HUNK_HEADER.match(line)
|
match = RE_HUNK_HEADER.match(line)
|
||||||
# identify hunk header
|
# identify hunk header
|
||||||
if match:
|
if match:
|
||||||
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
|
(
|
||||||
file.patches_range.append({'start': start2, 'end': start2 + size2 - 1})
|
section_header,
|
||||||
|
size1,
|
||||||
|
size2,
|
||||||
|
start1,
|
||||||
|
start2,
|
||||||
|
) = extract_hunk_headers(match)
|
||||||
|
file.patches_range.append(
|
||||||
|
{'start': start2, 'end': start2 + size2 - 1}
|
||||||
|
)
|
||||||
|
|
||||||
patches_range = file.patches_range
|
patches_range = file.patches_range
|
||||||
comment_start_line = suggestion.get('relevant_lines_start', None)
|
comment_start_line = suggestion.get(
|
||||||
|
'relevant_lines_start', None
|
||||||
|
)
|
||||||
comment_end_line = suggestion.get('relevant_lines_end', None)
|
comment_end_line = suggestion.get('relevant_lines_end', None)
|
||||||
original_suggestion = suggestion.get('original_suggestion', None) # needed for diff code
|
original_suggestion = suggestion.get(
|
||||||
if not comment_start_line or not comment_end_line or not original_suggestion:
|
'original_suggestion', None
|
||||||
|
) # needed for diff code
|
||||||
|
if (
|
||||||
|
not comment_start_line
|
||||||
|
or not comment_end_line
|
||||||
|
or not original_suggestion
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# check if the comment is inside a valid hunk
|
# check if the comment is inside a valid hunk
|
||||||
@ -1037,30 +1254,57 @@ class GithubProvider(GitProvider):
|
|||||||
patch_range_min = patch_range
|
patch_range_min = patch_range
|
||||||
min_distance = min(min_distance, d)
|
min_distance = min(min_distance, d)
|
||||||
if not is_valid_hunk:
|
if not is_valid_hunk:
|
||||||
if min_distance < 10: # 10 lines - a reasonable distance to consider the comment inside the hunk
|
if (
|
||||||
|
min_distance < 10
|
||||||
|
): # 10 lines - a reasonable distance to consider the comment inside the hunk
|
||||||
# make the suggestion non-committable, yet multi line
|
# make the suggestion non-committable, yet multi line
|
||||||
suggestion['relevant_lines_start'] = max(suggestion['relevant_lines_start'], patch_range_min['start'])
|
suggestion['relevant_lines_start'] = max(
|
||||||
suggestion['relevant_lines_end'] = min(suggestion['relevant_lines_end'], patch_range_min['end'])
|
suggestion['relevant_lines_start'],
|
||||||
|
patch_range_min['start'],
|
||||||
|
)
|
||||||
|
suggestion['relevant_lines_end'] = min(
|
||||||
|
suggestion['relevant_lines_end'],
|
||||||
|
patch_range_min['end'],
|
||||||
|
)
|
||||||
body = suggestion['body'].strip()
|
body = suggestion['body'].strip()
|
||||||
|
|
||||||
# present new diff code in collapsible
|
# present new diff code in collapsible
|
||||||
existing_code = original_suggestion['existing_code'].rstrip() + "\n"
|
existing_code = (
|
||||||
improved_code = original_suggestion['improved_code'].rstrip() + "\n"
|
original_suggestion['existing_code'].rstrip() + "\n"
|
||||||
diff = difflib.unified_diff(existing_code.split('\n'),
|
)
|
||||||
improved_code.split('\n'), n=999)
|
improved_code = (
|
||||||
|
original_suggestion['improved_code'].rstrip() + "\n"
|
||||||
|
)
|
||||||
|
diff = difflib.unified_diff(
|
||||||
|
existing_code.split('\n'),
|
||||||
|
improved_code.split('\n'),
|
||||||
|
n=999,
|
||||||
|
)
|
||||||
patch_orig = "\n".join(diff)
|
patch_orig = "\n".join(diff)
|
||||||
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
|
patch = "\n".join(patch_orig.splitlines()[5:]).strip(
|
||||||
|
'\n'
|
||||||
|
)
|
||||||
diff_code = f"\n\n<details><summary>新提议的代码:</summary>\n\n```diff\n{patch.rstrip()}\n```"
|
diff_code = f"\n\n<details><summary>新提议的代码:</summary>\n\n```diff\n{patch.rstrip()}\n```"
|
||||||
# replace ```suggestion ... ``` with diff_code, using regex:
|
# replace ```suggestion ... ``` with diff_code, using regex:
|
||||||
body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL)
|
body = re.sub(
|
||||||
|
r'```suggestion.*?```',
|
||||||
|
diff_code,
|
||||||
|
body,
|
||||||
|
flags=re.DOTALL,
|
||||||
|
)
|
||||||
body += "\n\n</details>"
|
body += "\n\n</details>"
|
||||||
suggestion['body'] = body
|
suggestion['body'] = body
|
||||||
get_logger().info(f"Comment was moved to a valid hunk, "
|
get_logger().info(
|
||||||
f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}")
|
f"Comment was moved to a valid hunk, "
|
||||||
|
f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
get_logger().error(f"Comment is not inside a valid hunk, "
|
get_logger().error(
|
||||||
f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}")
|
f"Comment is not inside a valid hunk, "
|
||||||
|
f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Failed to process patch for committable comment, error: {e}")
|
get_logger().error(
|
||||||
|
f"Failed to process patch for committable comment, error: {e}"
|
||||||
|
)
|
||||||
return code_suggestions_copy
|
return code_suggestions_copy
|
||||||
|
|
||||||
|
|||||||
@@ -10,9 +10,11 @@ from utils.pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
 
 from ..algo.file_filter import filter_ignored
 from ..algo.language_handler import is_valid_file
-from ..algo.utils import (clip_tokens,
-                          find_line_number_of_relevant_line_in_file,
-                          load_large_diff)
+from ..algo.utils import (
+    clip_tokens,
+    find_line_number_of_relevant_line_in_file,
+    load_large_diff,
+)
 from ..config_loader import get_settings
 from ..log import get_logger
 from .git_provider import MAX_FILES_ALLOWED_FULL, GitProvider
@@ -20,22 +22,26 @@ from .git_provider import MAX_FILES_ALLOWED_FULL, GitProvider
 
 class DiffNotFoundError(Exception):
     """Raised when the diff for a merge request cannot be found."""
 
     pass
 
 class GitLabProvider(GitProvider):
-
-    def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
+    def __init__(
+        self,
+        merge_request_url: Optional[str] = None,
+        incremental: Optional[bool] = False,
+    ):
         gitlab_url = get_settings().get("GITLAB.URL", None)
         if not gitlab_url:
             raise ValueError("GitLab URL is not set in the config file")
         self.gitlab_url = gitlab_url
         gitlab_access_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
         if not gitlab_access_token:
-            raise ValueError("GitLab personal access token is not set in the config file")
-        self.gl = gitlab.Gitlab(
-            url=gitlab_url,
-            oauth_token=gitlab_access_token
-        )
+            raise ValueError(
+                "GitLab personal access token is not set in the config file"
+            )
+        self.gl = gitlab.Gitlab(url=gitlab_url, oauth_token=gitlab_access_token)
         self.max_comment_chars = 65000
         self.id_project = None
         self.id_mr = None
||||||
@ -46,12 +52,17 @@ class GitLabProvider(GitProvider):
|
|||||||
self.pr_url = merge_request_url
|
self.pr_url = merge_request_url
|
||||||
self._set_merge_request(merge_request_url)
|
self._set_merge_request(merge_request_url)
|
||||||
self.RE_HUNK_HEADER = re.compile(
|
self.RE_HUNK_HEADER = re.compile(
|
||||||
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
|
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)"
|
||||||
|
)
|
||||||
self.incremental = incremental
|
self.incremental = incremental
|
||||||
|
|
||||||
def is_supported(self, capability: str) -> bool:
|
def is_supported(self, capability: str) -> bool:
|
||||||
if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments',
|
if capability in [
|
||||||
'publish_file_comments']: # gfm_markdown is supported in gitlab !
|
'get_issue_comments',
|
||||||
|
'create_inline_comment',
|
||||||
|
'publish_inline_comments',
|
||||||
|
'publish_file_comments',
|
||||||
|
]: # gfm_markdown is supported in gitlab !
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -67,12 +78,17 @@ class GitLabProvider(GitProvider):
|
|||||||
self.last_diff = self.mr.diffs.list(get_all=True)[-1]
|
self.last_diff = self.mr.diffs.list(get_all=True)[-1]
|
||||||
except IndexError as e:
|
except IndexError as e:
|
||||||
get_logger().error(f"Could not get diff for merge request {self.id_mr}")
|
get_logger().error(f"Could not get diff for merge request {self.id_mr}")
|
||||||
raise DiffNotFoundError(f"Could not get diff for merge request {self.id_mr}") from e
|
raise DiffNotFoundError(
|
||||||
|
f"Could not get diff for merge request {self.id_mr}"
|
||||||
|
) from e
|
||||||
|
|
||||||
def get_pr_file_content(self, file_path: str, branch: str) -> str:
|
def get_pr_file_content(self, file_path: str, branch: str) -> str:
|
||||||
try:
|
try:
|
||||||
return self.gl.projects.get(self.id_project).files.get(file_path, branch).decode()
|
return (
|
||||||
|
self.gl.projects.get(self.id_project)
|
||||||
|
.files.get(file_path, branch)
|
||||||
|
.decode()
|
||||||
|
)
|
||||||
except GitlabGetError:
|
except GitlabGetError:
|
||||||
# In case of file creation the method returns GitlabGetError (404 file not found).
|
# In case of file creation the method returns GitlabGetError (404 file not found).
|
||||||
# In this case we return an empty string for the diff.
|
# In this case we return an empty string for the diff.
|
||||||
@ -98,10 +114,13 @@ class GitLabProvider(GitProvider):
|
|||||||
try:
|
try:
|
||||||
names_original = [diff['new_path'] for diff in diffs_original]
|
names_original = [diff['new_path'] for diff in diffs_original]
|
||||||
names_filtered = [diff['new_path'] for diff in diffs]
|
names_filtered = [diff['new_path'] for diff in diffs]
|
||||||
get_logger().info(f"Filtered out [ignore] files for merge request {self.id_mr}", extra={
|
get_logger().info(
|
||||||
'original_files': names_original,
|
f"Filtered out [ignore] files for merge request {self.id_mr}",
|
||||||
'filtered_files': names_filtered
|
extra={
|
||||||
})
|
'original_files': names_original,
|
||||||
|
'filtered_files': names_filtered,
|
||||||
|
},
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -116,22 +135,31 @@ class GitLabProvider(GitProvider):
|
|||||||
# allow only a limited number of files to be fully loaded. We can manage the rest with diffs only
|
# allow only a limited number of files to be fully loaded. We can manage the rest with diffs only
|
||||||
counter_valid += 1
|
counter_valid += 1
|
||||||
if counter_valid < MAX_FILES_ALLOWED_FULL or not diff['diff']:
|
if counter_valid < MAX_FILES_ALLOWED_FULL or not diff['diff']:
|
||||||
original_file_content_str = self.get_pr_file_content(diff['old_path'], self.mr.diff_refs['base_sha'])
|
original_file_content_str = self.get_pr_file_content(
|
||||||
new_file_content_str = self.get_pr_file_content(diff['new_path'], self.mr.diff_refs['head_sha'])
|
diff['old_path'], self.mr.diff_refs['base_sha']
|
||||||
|
)
|
||||||
|
new_file_content_str = self.get_pr_file_content(
|
||||||
|
diff['new_path'], self.mr.diff_refs['head_sha']
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if counter_valid == MAX_FILES_ALLOWED_FULL:
|
if counter_valid == MAX_FILES_ALLOWED_FULL:
|
||||||
get_logger().info(f"Too many files in PR, will avoid loading full content for rest of files")
|
get_logger().info(
|
||||||
|
f"Too many files in PR, will avoid loading full content for rest of files"
|
||||||
|
)
|
||||||
original_file_content_str = ''
|
original_file_content_str = ''
|
||||||
new_file_content_str = ''
|
new_file_content_str = ''
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if isinstance(original_file_content_str, bytes):
|
if isinstance(original_file_content_str, bytes):
|
||||||
original_file_content_str = bytes.decode(original_file_content_str, 'utf-8')
|
original_file_content_str = bytes.decode(
|
||||||
|
original_file_content_str, 'utf-8'
|
||||||
|
)
|
||||||
if isinstance(new_file_content_str, bytes):
|
if isinstance(new_file_content_str, bytes):
|
||||||
new_file_content_str = bytes.decode(new_file_content_str, 'utf-8')
|
new_file_content_str = bytes.decode(new_file_content_str, 'utf-8')
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
get_logger().warning(
|
get_logger().warning(
|
||||||
f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}")
|
f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}"
|
||||||
|
)
|
||||||
|
|
||||||
edit_type = EDIT_TYPE.MODIFIED
|
edit_type = EDIT_TYPE.MODIFIED
|
||||||
if diff['new_file']:
|
if diff['new_file']:
|
||||||
@ -144,30 +172,43 @@ class GitLabProvider(GitProvider):
|
|||||||
filename = diff['new_path']
|
filename = diff['new_path']
|
||||||
patch = diff['diff']
|
patch = diff['diff']
|
||||||
if not patch:
|
if not patch:
|
||||||
patch = load_large_diff(filename, new_file_content_str, original_file_content_str)
|
patch = load_large_diff(
|
||||||
|
filename, new_file_content_str, original_file_content_str
|
||||||
|
)
|
||||||
|
|
||||||
# count number of lines added and removed
|
# count number of lines added and removed
|
||||||
patch_lines = patch.splitlines(keepends=True)
|
patch_lines = patch.splitlines(keepends=True)
|
||||||
num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
|
num_plus_lines = len([line for line in patch_lines if line.startswith('+')])
|
||||||
num_minus_lines = len([line for line in patch_lines if line.startswith('-')])
|
num_minus_lines = len(
|
||||||
|
[line for line in patch_lines if line.startswith('-')]
|
||||||
|
)
|
||||||
diff_files.append(
|
diff_files.append(
|
||||||
FilePatchInfo(original_file_content_str, new_file_content_str,
|
FilePatchInfo(
|
||||||
patch=patch,
|
original_file_content_str,
|
||||||
filename=filename,
|
new_file_content_str,
|
||||||
edit_type=edit_type,
|
patch=patch,
|
||||||
old_filename=None if diff['old_path'] == diff['new_path'] else diff['old_path'],
|
filename=filename,
|
||||||
num_plus_lines=num_plus_lines,
|
edit_type=edit_type,
|
||||||
num_minus_lines=num_minus_lines, ))
|
old_filename=None
|
||||||
|
if diff['old_path'] == diff['new_path']
|
||||||
|
else diff['old_path'],
|
||||||
|
num_plus_lines=num_plus_lines,
|
||||||
|
num_minus_lines=num_minus_lines,
|
||||||
|
)
|
||||||
|
)
|
||||||
if invalid_files_names:
|
if invalid_files_names:
|
||||||
get_logger().info(f"Filtered out files with invalid extensions: {invalid_files_names}")
|
get_logger().info(
|
||||||
|
f"Filtered out files with invalid extensions: {invalid_files_names}"
|
||||||
|
)
|
||||||
|
|
||||||
self.diff_files = diff_files
|
self.diff_files = diff_files
|
||||||
return diff_files
|
return diff_files
|
||||||
|
|
||||||
def get_files(self) -> list:
|
def get_files(self) -> list:
|
||||||
if not self.git_files:
|
if not self.git_files:
|
||||||
self.git_files = [change['new_path'] for change in self.mr.changes()['changes']]
|
self.git_files = [
|
||||||
|
change['new_path'] for change in self.mr.changes()['changes']
|
||||||
|
]
|
||||||
return self.git_files
|
return self.git_files
|
||||||
|
|
||||||
def publish_description(self, pr_title: str, pr_body: str):
|
def publish_description(self, pr_title: str, pr_body: str):
|
||||||
@ -176,7 +217,9 @@ class GitLabProvider(GitProvider):
|
|||||||
self.mr.description = pr_body
|
self.mr.description = pr_body
|
||||||
self.mr.save()
|
self.mr.save()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().exception(f"Could not update merge request {self.id_mr} description: {e}")
|
get_logger().exception(
|
||||||
|
f"Could not update merge request {self.id_mr} description: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
def get_latest_commit_url(self):
|
def get_latest_commit_url(self):
|
||||||
return self.mr.commits().next().web_url
|
return self.mr.commits().next().web_url
|
||||||
@ -184,16 +227,23 @@ class GitLabProvider(GitProvider):
|
|||||||
def get_comment_url(self, comment):
|
def get_comment_url(self, comment):
|
||||||
return f"{self.mr.web_url}#note_{comment.id}"
|
return f"{self.mr.web_url}#note_{comment.id}"
|
||||||
|
|
||||||
def publish_persistent_comment(self, pr_comment: str,
|
def publish_persistent_comment(
|
||||||
initial_header: str,
|
self,
|
||||||
update_header: bool = True,
|
pr_comment: str,
|
||||||
name='review',
|
initial_header: str,
|
||||||
final_update_message=True):
|
update_header: bool = True,
|
||||||
self.publish_persistent_comment_full(pr_comment, initial_header, update_header, name, final_update_message)
|
name='review',
|
||||||
|
final_update_message=True,
|
||||||
|
):
|
||||||
|
self.publish_persistent_comment_full(
|
||||||
|
pr_comment, initial_header, update_header, name, final_update_message
|
||||||
|
)
|
||||||
|
|
||||||
def publish_comment(self, mr_comment: str, is_temporary: bool = False):
|
def publish_comment(self, mr_comment: str, is_temporary: bool = False):
|
||||||
if is_temporary and not get_settings().config.publish_output_progress:
|
if is_temporary and not get_settings().config.publish_output_progress:
|
||||||
get_logger().debug(f"Skipping publish_comment for temporary comment: {mr_comment}")
|
get_logger().debug(
|
||||||
|
f"Skipping publish_comment for temporary comment: {mr_comment}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
mr_comment = self.limit_output_characters(mr_comment, self.max_comment_chars)
|
mr_comment = self.limit_output_characters(mr_comment, self.max_comment_chars)
|
||||||
comment = self.mr.notes.create({'body': mr_comment})
|
comment = self.mr.notes.create({'body': mr_comment})
|
||||||
@ -203,7 +253,7 @@ class GitLabProvider(GitProvider):
|
|||||||
|
|
||||||
def edit_comment(self, comment, body: str):
|
def edit_comment(self, comment, body: str):
|
||||||
body = self.limit_output_characters(body, self.max_comment_chars)
|
body = self.limit_output_characters(body, self.max_comment_chars)
|
||||||
self.mr.notes.update(comment.id,{'body': body} )
|
self.mr.notes.update(comment.id, {'body': body})
|
||||||
|
|
||||||
def edit_comment_from_comment_id(self, comment_id: int, body: str):
|
def edit_comment_from_comment_id(self, comment_id: int, body: str):
|
||||||
body = self.limit_output_characters(body, self.max_comment_chars)
|
body = self.limit_output_characters(body, self.max_comment_chars)
|
||||||
@ -216,39 +266,87 @@ class GitLabProvider(GitProvider):
|
|||||||
discussion = self.mr.discussions.get(comment_id)
|
discussion = self.mr.discussions.get(comment_id)
|
||||||
discussion.notes.create({'body': body})
|
discussion.notes.create({'body': body})
|
||||||
|
|
||||||
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
|
def publish_inline_comment(
|
||||||
|
self,
|
||||||
|
body: str,
|
||||||
|
relevant_file: str,
|
||||||
|
relevant_line_in_file: str,
|
||||||
|
original_suggestion=None,
|
||||||
|
):
|
||||||
body = self.limit_output_characters(body, self.max_comment_chars)
|
body = self.limit_output_characters(body, self.max_comment_chars)
|
||||||
edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file,
|
(
|
||||||
relevant_line_in_file)
|
edit_type,
|
||||||
self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
|
found,
|
||||||
target_file, target_line_no, original_suggestion)
|
source_line_no,
|
||||||
|
target_file,
|
||||||
|
target_line_no,
|
||||||
|
) = self.search_line(relevant_file, relevant_line_in_file)
|
||||||
|
self.send_inline_comment(
|
||||||
|
body,
|
||||||
|
edit_type,
|
||||||
|
found,
|
||||||
|
relevant_file,
|
||||||
|
relevant_line_in_file,
|
||||||
|
source_line_no,
|
||||||
|
target_file,
|
||||||
|
target_line_no,
|
||||||
|
original_suggestion,
|
||||||
|
)
|
||||||
|
|
||||||
def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, absolute_position: int = None):
|
def create_inline_comment(
|
||||||
raise NotImplementedError("Gitlab provider does not support creating inline comments yet")
|
self,
|
||||||
|
body: str,
|
||||||
|
relevant_file: str,
|
||||||
|
relevant_line_in_file: str,
|
||||||
|
absolute_position: int = None,
|
||||||
|
):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Gitlab provider does not support creating inline comments yet"
|
||||||
|
)
|
||||||
|
|
||||||
def create_inline_comments(self, comments: list[dict]):
|
def create_inline_comments(self, comments: list[dict]):
|
||||||
raise NotImplementedError("Gitlab provider does not support publishing inline comments yet")
|
raise NotImplementedError(
|
||||||
|
"Gitlab provider does not support publishing inline comments yet"
|
||||||
|
)
|
||||||
|
|
||||||
def get_comment_body_from_comment_id(self, comment_id: int):
|
def get_comment_body_from_comment_id(self, comment_id: int):
|
||||||
comment = self.mr.notes.get(comment_id).body
|
comment = self.mr.notes.get(comment_id).body
|
||||||
return comment
|
return comment
|
||||||
|
|
||||||
def send_inline_comment(self, body: str, edit_type: str, found: bool, relevant_file: str,
|
def send_inline_comment(
|
||||||
relevant_line_in_file: str,
|
self,
|
||||||
source_line_no: int, target_file: str, target_line_no: int,
|
body: str,
|
||||||
original_suggestion=None) -> None:
|
edit_type: str,
|
||||||
|
found: bool,
|
||||||
|
relevant_file: str,
|
||||||
|
relevant_line_in_file: str,
|
||||||
|
source_line_no: int,
|
||||||
|
target_file: str,
|
||||||
|
target_line_no: int,
|
||||||
|
original_suggestion=None,
|
||||||
|
) -> None:
|
||||||
if not found:
|
if not found:
|
||||||
get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
|
get_logger().info(
|
||||||
|
f"Could not find position for {relevant_file} {relevant_line_in_file}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# in order to have exact sha's we have to find correct diff for this change
|
# in order to have exact sha's we have to find correct diff for this change
|
||||||
diff = self.get_relevant_diff(relevant_file, relevant_line_in_file)
|
diff = self.get_relevant_diff(relevant_file, relevant_line_in_file)
|
||||||
if diff is None:
|
if diff is None:
|
||||||
get_logger().error(f"Could not get diff for merge request {self.id_mr}")
|
get_logger().error(f"Could not get diff for merge request {self.id_mr}")
|
||||||
raise DiffNotFoundError(f"Could not get diff for merge request {self.id_mr}")
|
raise DiffNotFoundError(
|
||||||
pos_obj = {'position_type': 'text',
|
f"Could not get diff for merge request {self.id_mr}"
|
||||||
'new_path': target_file.filename,
|
)
|
||||||
'old_path': target_file.old_filename if target_file.old_filename else target_file.filename,
|
pos_obj = {
|
||||||
'base_sha': diff.base_commit_sha, 'start_sha': diff.start_commit_sha, 'head_sha': diff.head_commit_sha}
|
'position_type': 'text',
|
||||||
|
'new_path': target_file.filename,
|
||||||
|
'old_path': target_file.old_filename
|
||||||
|
if target_file.old_filename
|
||||||
|
else target_file.filename,
|
||||||
|
'base_sha': diff.base_commit_sha,
|
||||||
|
'start_sha': diff.start_commit_sha,
|
||||||
|
'head_sha': diff.head_commit_sha,
|
||||||
|
}
|
||||||
if edit_type == 'deletion':
|
if edit_type == 'deletion':
|
||||||
pos_obj['old_line'] = source_line_no - 1
|
pos_obj['old_line'] = source_line_no - 1
|
||||||
elif edit_type == 'addition':
|
elif edit_type == 'addition':
|
||||||
@ -256,15 +354,21 @@ class GitLabProvider(GitProvider):
|
|||||||
else:
|
else:
|
||||||
pos_obj['new_line'] = target_line_no - 1
|
pos_obj['new_line'] = target_line_no - 1
|
||||||
pos_obj['old_line'] = source_line_no - 1
|
pos_obj['old_line'] = source_line_no - 1
|
||||||
get_logger().debug(f"Creating comment in MR {self.id_mr} with body {body} and position {pos_obj}")
|
get_logger().debug(
|
||||||
|
f"Creating comment in MR {self.id_mr} with body {body} and position {pos_obj}"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
self.mr.discussions.create({'body': body, 'position': pos_obj})
|
self.mr.discussions.create({'body': body, 'position': pos_obj})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
try:
|
try:
|
||||||
# fallback - create a general note on the file in the MR
|
# fallback - create a general note on the file in the MR
|
||||||
if 'suggestion_orig_location' in original_suggestion:
|
if 'suggestion_orig_location' in original_suggestion:
|
||||||
line_start = original_suggestion['suggestion_orig_location']['start_line']
|
line_start = original_suggestion['suggestion_orig_location'][
|
||||||
line_end = original_suggestion['suggestion_orig_location']['end_line']
|
'start_line'
|
||||||
|
]
|
||||||
|
line_end = original_suggestion['suggestion_orig_location'][
|
||||||
|
'end_line'
|
||||||
|
]
|
||||||
old_code_snippet = original_suggestion['prev_code_snippet']
|
old_code_snippet = original_suggestion['prev_code_snippet']
|
||||||
new_code_snippet = original_suggestion['new_code_snippet']
|
new_code_snippet = original_suggestion['new_code_snippet']
|
||||||
content = original_suggestion['suggestion_summary']
|
content = original_suggestion['suggestion_summary']
|
||||||
@ -287,36 +391,49 @@ class GitLabProvider(GitProvider):
|
|||||||
else:
|
else:
|
||||||
language = ''
|
language = ''
|
||||||
link = self.get_line_link(relevant_file, line_start, line_end)
|
link = self.get_line_link(relevant_file, line_start, line_end)
|
||||||
body_fallback =f"**Suggestion:** {content} [{label}, importance: {score}]\n\n"
|
body_fallback = (
|
||||||
body_fallback +=f"\n\n<details><summary>[{target_file.filename} [{line_start}-{line_end}]]({link}):</summary>\n\n"
|
f"**Suggestion:** {content} [{label}, importance: {score}]\n\n"
|
||||||
|
)
|
||||||
|
body_fallback += f"\n\n<details><summary>[{target_file.filename} [{line_start}-{line_end}]]({link}):</summary>\n\n"
|
||||||
body_fallback += f"\n\n___\n\n`(Cannot implement directly - GitLab API allows committable suggestions strictly on MR diff lines)`"
|
body_fallback += f"\n\n___\n\n`(Cannot implement directly - GitLab API allows committable suggestions strictly on MR diff lines)`"
|
||||||
body_fallback+="</details>\n\n"
|
body_fallback += "</details>\n\n"
|
||||||
diff_patch = difflib.unified_diff(old_code_snippet.split('\n'),
|
diff_patch = difflib.unified_diff(
|
||||||
new_code_snippet.split('\n'), n=999)
|
old_code_snippet.split('\n'),
|
||||||
|
new_code_snippet.split('\n'),
|
||||||
|
n=999,
|
||||||
|
)
|
||||||
patch_orig = "\n".join(diff_patch)
|
patch_orig = "\n".join(diff_patch)
|
||||||
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
|
patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n')
|
||||||
diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
|
diff_code = f"\n\n```diff\n{patch.rstrip()}\n```"
|
||||||
body_fallback += diff_code
|
body_fallback += diff_code
|
||||||
|
|
||||||
# Create a general note on the file in the MR
|
# Create a general note on the file in the MR
|
||||||
self.mr.notes.create({
|
self.mr.notes.create(
|
||||||
'body': body_fallback,
|
{
|
||||||
'position': {
|
'body': body_fallback,
|
||||||
'base_sha': diff.base_commit_sha,
|
'position': {
|
||||||
'start_sha': diff.start_commit_sha,
|
'base_sha': diff.base_commit_sha,
|
||||||
'head_sha': diff.head_commit_sha,
|
'start_sha': diff.start_commit_sha,
|
||||||
'position_type': 'text',
|
'head_sha': diff.head_commit_sha,
|
||||||
'file_path': f'{target_file.filename}',
|
'position_type': 'text',
|
||||||
|
'file_path': f'{target_file.filename}',
|
||||||
|
},
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
get_logger().debug(f"Created fallback comment in MR {self.id_mr} with position {pos_obj}")
|
get_logger().debug(
|
||||||
|
f"Created fallback comment in MR {self.id_mr} with position {pos_obj}"
|
||||||
|
)
|
||||||
|
|
||||||
# get_logger().debug(
|
# get_logger().debug(
|
||||||
# f"Failed to create comment in MR {self.id_mr} with position {pos_obj} (probably not a '+' line)")
|
# f"Failed to create comment in MR {self.id_mr} with position {pos_obj} (probably not a '+' line)")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().exception(f"Failed to create comment in MR {self.id_mr}")
|
get_logger().exception(
|
||||||
|
f"Failed to create comment in MR {self.id_mr}"
|
||||||
|
)
|
||||||
|
|
||||||
def get_relevant_diff(self, relevant_file: str, relevant_line_in_file: str) -> Optional[dict]:
|
def get_relevant_diff(
|
||||||
|
self, relevant_file: str, relevant_line_in_file: str
|
||||||
|
) -> Optional[dict]:
|
||||||
changes = self.mr.changes() # Retrieve the changes for the merge request once
|
changes = self.mr.changes() # Retrieve the changes for the merge request once
|
||||||
if not changes:
|
if not changes:
|
||||||
get_logger().error('No changes found for the merge request.')
|
get_logger().error('No changes found for the merge request.')
|
||||||
@ -327,10 +444,14 @@ class GitLabProvider(GitProvider):
|
|||||||
return None
|
return None
|
||||||
for diff in all_diffs:
|
for diff in all_diffs:
|
||||||
for change in changes['changes']:
|
for change in changes['changes']:
|
||||||
if change['new_path'] == relevant_file and relevant_line_in_file in change['diff']:
|
if (
|
||||||
|
change['new_path'] == relevant_file
|
||||||
|
and relevant_line_in_file in change['diff']
|
||||||
|
):
|
||||||
return diff
|
return diff
|
||||||
get_logger().debug(
|
get_logger().debug(
|
||||||
f'No relevant diff found for {relevant_file} {relevant_line_in_file}. Falling back to last diff.')
|
f'No relevant diff found for {relevant_file} {relevant_line_in_file}. Falling back to last diff.'
|
||||||
|
)
|
||||||
return self.last_diff # fallback to last_diff if no relevant diff is found
|
return self.last_diff # fallback to last_diff if no relevant diff is found
|
||||||
|
|
||||||
def publish_code_suggestions(self, code_suggestions: list) -> bool:
|
def publish_code_suggestions(self, code_suggestions: list) -> bool:
|
||||||
@ -352,7 +473,7 @@ class GitLabProvider(GitProvider):
|
|||||||
if file.filename == relevant_file:
|
if file.filename == relevant_file:
|
||||||
target_file = file
|
target_file = file
|
||||||
break
|
break
|
||||||
range = relevant_lines_end - relevant_lines_start # no need to add 1
|
range = relevant_lines_end - relevant_lines_start # no need to add 1
|
||||||
body = body.replace('```suggestion', f'```suggestion:-0+{range}')
|
body = body.replace('```suggestion', f'```suggestion:-0+{range}')
|
||||||
lines = target_file.head_file.splitlines()
|
lines = target_file.head_file.splitlines()
|
||||||
relevant_line_in_file = lines[relevant_lines_start - 1]
|
relevant_line_in_file = lines[relevant_lines_start - 1]
|
||||||
@ -365,10 +486,21 @@ class GitLabProvider(GitProvider):
|
|||||||
found = True
|
found = True
|
||||||
edit_type = 'addition'
|
edit_type = 'addition'
|
||||||
|
|
||||||
self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
|
self.send_inline_comment(
|
||||||
target_file, target_line_no, original_suggestion)
|
body,
|
||||||
|
edit_type,
|
||||||
|
found,
|
||||||
|
relevant_file,
|
||||||
|
relevant_line_in_file,
|
||||||
|
source_line_no,
|
||||||
|
target_file,
|
||||||
|
target_line_no,
|
||||||
|
original_suggestion,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().exception(f"Could not publish code suggestion:\nsuggestion: {suggestion}\nerror: {e}")
|
get_logger().exception(
|
||||||
|
f"Could not publish code suggestion:\nsuggestion: {suggestion}\nerror: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
# note that we publish suggestions one-by-one. so, if one fails, the rest will still be published
|
# note that we publish suggestions one-by-one. so, if one fails, the rest will still be published
|
||||||
return True
|
return True
|
||||||
@ -382,8 +514,13 @@ class GitLabProvider(GitProvider):
|
|||||||
edit_type = self.get_edit_type(relevant_line_in_file)
|
edit_type = self.get_edit_type(relevant_line_in_file)
|
||||||
for file in self.get_diff_files():
|
for file in self.get_diff_files():
|
||||||
if file.filename == relevant_file:
|
if file.filename == relevant_file:
|
||||||
edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(file,
|
(
|
||||||
relevant_line_in_file)
|
edit_type,
|
||||||
|
found,
|
||||||
|
source_line_no,
|
||||||
|
target_file,
|
||||||
|
target_line_no,
|
||||||
|
) = self.find_in_file(file, relevant_line_in_file)
|
||||||
return edit_type, found, source_line_no, target_file, target_line_no
|
return edit_type, found, source_line_no, target_file, target_line_no
|
||||||
|
|
||||||
def find_in_file(self, file, relevant_line_in_file):
|
def find_in_file(self, file, relevant_line_in_file):
|
||||||
@ -414,7 +551,10 @@ class GitLabProvider(GitProvider):
|
|||||||
found = True
|
found = True
|
||||||
edit_type = self.get_edit_type(line)
|
edit_type = self.get_edit_type(line)
|
||||||
break
|
break
|
||||||
elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:].lstrip() in line:
|
elif (
|
||||||
|
relevant_line_in_file[0] == '+'
|
||||||
|
and relevant_line_in_file[1:].lstrip() in line
|
||||||
|
):
|
||||||
# The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
|
# The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
|
||||||
# it's a context line
|
# it's a context line
|
||||||
found = True
|
found = True
|
||||||
@ -470,7 +610,11 @@ class GitLabProvider(GitProvider):
|
|||||||
|
|
||||||
def get_repo_settings(self):
|
def get_repo_settings(self):
|
||||||
try:
|
try:
|
||||||
contents = self.gl.projects.get(self.id_project).files.get(file_path='.pr_agent.toml', ref=self.mr.target_branch).decode()
|
contents = (
|
||||||
|
self.gl.projects.get(self.id_project)
|
||||||
|
.files.get(file_path='.pr_agent.toml', ref=self.mr.target_branch)
|
||||||
|
.decode()
|
||||||
|
)
|
||||||
return contents
|
return contents
|
||||||
except Exception:
|
except Exception:
|
||||||
return ""
|
return ""
|
||||||
@ -478,7 +622,9 @@ class GitLabProvider(GitProvider):
|
|||||||
def get_workspace_name(self):
|
def get_workspace_name(self):
|
||||||
return self.id_project.split('/')[0]
|
return self.id_project.split('/')[0]
|
||||||
|
|
||||||
def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]:
|
def add_eyes_reaction(
|
||||||
|
self, issue_comment_id: int, disable_eyes: bool = False
|
||||||
|
) -> Optional[int]:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
|
def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool:
|
||||||
@ -489,7 +635,9 @@ class GitLabProvider(GitProvider):
|
|||||||
|
|
||||||
path_parts = parsed_url.path.strip('/').split('/')
|
path_parts = parsed_url.path.strip('/').split('/')
|
||||||
if 'merge_requests' not in path_parts:
|
if 'merge_requests' not in path_parts:
|
||||||
raise ValueError("The provided URL does not appear to be a GitLab merge request URL")
|
raise ValueError(
|
||||||
|
"The provided URL does not appear to be a GitLab merge request URL"
|
||||||
|
)
|
||||||
|
|
||||||
mr_index = path_parts.index('merge_requests')
|
mr_index = path_parts.index('merge_requests')
|
||||||
# Ensure there is an ID after 'merge_requests'
|
# Ensure there is an ID after 'merge_requests'
|
||||||
@ -541,8 +689,15 @@ class GitLabProvider(GitProvider):
|
|||||||
"""
|
"""
|
||||||
max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None)
|
max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None)
|
||||||
try:
|
try:
|
||||||
commit_messages_list = [commit['message'] for commit in self.mr.commits()._list]
|
commit_messages_list = [
|
||||||
commit_messages_str = "\n".join([f"{i + 1}. {message}" for i, message in enumerate(commit_messages_list)])
|
commit['message'] for commit in self.mr.commits()._list
|
||||||
|
]
|
||||||
|
commit_messages_str = "\n".join(
|
||||||
|
[
|
||||||
|
f"{i + 1}. {message}"
|
||||||
|
for i, message in enumerate(commit_messages_list)
|
||||||
|
]
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
commit_messages_str = ""
|
commit_messages_str = ""
|
||||||
if max_tokens:
|
if max_tokens:
|
||||||
@ -556,7 +711,12 @@ class GitLabProvider(GitProvider):
|
|||||||
except:
|
except:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str:
|
def get_line_link(
|
||||||
|
self,
|
||||||
|
relevant_file: str,
|
||||||
|
relevant_line_start: int,
|
||||||
|
relevant_line_end: int = None,
|
||||||
|
) -> str:
|
||||||
if relevant_line_start == -1:
|
if relevant_line_start == -1:
|
||||||
link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads"
|
link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads"
|
||||||
elif relevant_line_end:
|
elif relevant_line_end:
|
||||||
@ -565,7 +725,6 @@ class GitLabProvider(GitProvider):
|
|||||||
link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads#L{relevant_line_start}"
|
link = f"{self.gl.url}/{self.id_project}/-/blob/{self.mr.source_branch}/{relevant_file}?ref_type=heads#L{relevant_line_start}"
|
||||||
return link
|
return link
|
||||||
|
|
||||||
|
|
||||||
def generate_link_to_relevant_line_number(self, suggestion) -> str:
|
def generate_link_to_relevant_line_number(self, suggestion) -> str:
|
||||||
try:
|
try:
|
||||||
relevant_file = suggestion['relevant_file'].strip('`').strip("'").rstrip()
|
relevant_file = suggestion['relevant_file'].strip('`').strip("'").rstrip()
|
||||||
@ -573,8 +732,9 @@ class GitLabProvider(GitProvider):
|
|||||||
if not relevant_line_str:
|
if not relevant_line_str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
position, absolute_position = find_line_number_of_relevant_line_in_file \
|
position, absolute_position = find_line_number_of_relevant_line_in_file(
|
||||||
(self.diff_files, relevant_file, relevant_line_str)
|
self.diff_files, relevant_file, relevant_line_str
|
||||||
|
)
|
||||||
|
|
||||||
if absolute_position != -1:
|
if absolute_position != -1:
|
||||||
# link to right file only
|
# link to right file only
|
||||||
|
|||||||
@ -39,10 +39,16 @@ class LocalGitProvider(GitProvider):
        self._prepare_repo()
        self.diff_files = None
        self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())
-        self.description_path = get_settings().get('local.description_path') \
-            if get_settings().get('local.description_path') is not None else self.repo_path / 'description.md'
-        self.review_path = get_settings().get('local.review_path') \
-            if get_settings().get('local.review_path') is not None else self.repo_path / 'review.md'
+        self.description_path = (
+            get_settings().get('local.description_path')
+            if get_settings().get('local.description_path') is not None
+            else self.repo_path / 'description.md'
+        )
+        self.review_path = (
+            get_settings().get('local.review_path')
+            if get_settings().get('local.review_path') is not None
+            else self.repo_path / 'review.md'
+        )
        # inline code comments are not supported for local git repositories
        get_settings().pr_reviewer.inline_code_comments = False

@ -52,30 +58,43 @@ class LocalGitProvider(GitProvider):
        """
        get_logger().debug('Preparing repository for PR-mimic generation...')
        if self.repo.is_dirty():
-            raise ValueError('The repository is not in a clean state. Please commit or stash pending changes.')
+            raise ValueError(
+                'The repository is not in a clean state. Please commit or stash pending changes.'
+            )
        if self.target_branch_name not in self.repo.heads:
            raise KeyError(f'Branch: {self.target_branch_name} does not exist')

    def is_supported(self, capability: str) -> bool:
-        if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments', 'get_labels',
-                          'gfm_markdown']:
+        if capability in [
+            'get_issue_comments',
+            'create_inline_comment',
+            'publish_inline_comments',
+            'get_labels',
+            'gfm_markdown',
+        ]:
            return False
        return True

    def get_diff_files(self) -> list[FilePatchInfo]:
        diffs = self.repo.head.commit.diff(
-            self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]),
+            self.repo.merge_base(
+                self.repo.head, self.repo.branches[self.target_branch_name]
+            ),
            create_patch=True,
-            R=True
+            R=True,
        )
        diff_files = []
        for diff_item in diffs:
            if diff_item.a_blob is not None:
-                original_file_content_str = diff_item.a_blob.data_stream.read().decode('utf-8')
+                original_file_content_str = diff_item.a_blob.data_stream.read().decode(
+                    'utf-8'
+                )
            else:
                original_file_content_str = "" # empty file
            if diff_item.b_blob is not None:
-                new_file_content_str = diff_item.b_blob.data_stream.read().decode('utf-8')
+                new_file_content_str = diff_item.b_blob.data_stream.read().decode(
+                    'utf-8'
+                )
            else:
                new_file_content_str = "" # empty file
            edit_type = EDIT_TYPE.MODIFIED
@ -86,13 +105,16 @@ class LocalGitProvider(GitProvider):
            elif diff_item.renamed_file:
                edit_type = EDIT_TYPE.RENAMED
            diff_files.append(
-                FilePatchInfo(original_file_content_str,
-                              new_file_content_str,
-                              diff_item.diff.decode('utf-8'),
-                              diff_item.b_path,
-                              edit_type=edit_type,
-                              old_filename=None if diff_item.a_path == diff_item.b_path else diff_item.a_path
-                              )
+                FilePatchInfo(
+                    original_file_content_str,
+                    new_file_content_str,
+                    diff_item.diff.decode('utf-8'),
+                    diff_item.b_path,
+                    edit_type=edit_type,
+                    old_filename=None
+                    if diff_item.a_path == diff_item.b_path
+                    else diff_item.a_path,
+                )
            )
        self.diff_files = diff_files
        return diff_files
@ -102,8 +124,10 @@ class LocalGitProvider(GitProvider):
        Returns a list of files with changes in the diff.
        """
        diff_index = self.repo.head.commit.diff(
-            self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]),
-            R=True
+            self.repo.merge_base(
+                self.repo.head, self.repo.branches[self.target_branch_name]
+            ),
+            R=True,
        )
        # Get the list of changed files
        diff_files = [item.a_path for item in diff_index]
@ -119,18 +143,37 @@ class LocalGitProvider(GitProvider):
            # Write the string to the file
            file.write(pr_comment)

-    def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None):
-        raise NotImplementedError('Publishing inline comments is not implemented for the local git provider')
+    def publish_inline_comment(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_line_in_file: str,
+        original_suggestion=None,
+    ):
+        raise NotImplementedError(
+            'Publishing inline comments is not implemented for the local git provider'
+        )

    def publish_inline_comments(self, comments: list[dict]):
-        raise NotImplementedError('Publishing inline comments is not implemented for the local git provider')
+        raise NotImplementedError(
+            'Publishing inline comments is not implemented for the local git provider'
+        )

-    def publish_code_suggestion(self, body: str, relevant_file: str,
-                                relevant_lines_start: int, relevant_lines_end: int):
-        raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')
+    def publish_code_suggestion(
+        self,
+        body: str,
+        relevant_file: str,
+        relevant_lines_start: int,
+        relevant_lines_end: int,
+    ):
+        raise NotImplementedError(
+            'Publishing code suggestions is not implemented for the local git provider'
+        )

    def publish_code_suggestions(self, code_suggestions: list) -> bool:
-        raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')
+        raise NotImplementedError(
+            'Publishing code suggestions is not implemented for the local git provider'
+        )

    def publish_labels(self, labels):
        pass # Not applicable to the local git provider, but required by the interface
@ -158,19 +201,31 @@ class LocalGitProvider(GitProvider):
        Calculate percentage of languages in repository. Used for hunk prioritisation.
        """
        # Get all files in repository
-        filepaths = [Path(item.path) for item in self.repo.tree().traverse() if item.type == 'blob']
+        filepaths = [
+            Path(item.path)
+            for item in self.repo.tree().traverse()
+            if item.type == 'blob'
+        ]
        # Identify language by file extension and count
-        lang_count = Counter(ext.lstrip('.') for filepath in filepaths for ext in [filepath.suffix.lower()])
+        lang_count = Counter(
+            ext.lstrip('.')
+            for filepath in filepaths
+            for ext in [filepath.suffix.lower()]
+        )
        # Convert counts to percentages
        total_files = len(filepaths)
-        lang_percentage = {lang: count / total_files * 100 for lang, count in lang_count.items()}
+        lang_percentage = {
+            lang: count / total_files * 100 for lang, count in lang_count.items()
+        }
        return lang_percentage

    def get_pr_branch(self):
        return self.repo.head

    def get_user_id(self):
-        return -1 # Not used anywhere for the local provider, but required by the interface
+        return (
+            -1
+        ) # Not used anywhere for the local provider, but required by the interface

    def get_pr_description_full(self):
        commits_diff = list(self.repo.iter_commits(self.target_branch_name + '..HEAD'))
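For reference, a small standalone sketch of the Counter-over-suffixes idea used in the language-percentage hunk above; the file paths below are made up for illustration, not taken from the repository:

from collections import Counter
from pathlib import Path

# Illustrative file list; in the provider this comes from the git tree traversal.
filepaths = [Path("app/main.py"), Path("app/util.py"), Path("README.md")]
lang_count = Counter(p.suffix.lower().lstrip('.') for p in filepaths)
total_files = len(filepaths)
lang_percentage = {lang: count / total_files * 100 for lang, count in lang_count.items()}
print(lang_percentage)  # e.g. {'py': 66.66..., 'md': 33.33...}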
@ -186,7 +241,11 @@ class LocalGitProvider(GitProvider):
        return self.head_branch_name

    def get_issue_comments(self):
-        raise NotImplementedError('Getting issue comments is not implemented for the local git provider')
+        raise NotImplementedError(
+            'Getting issue comments is not implemented for the local git provider'
+        )

    def get_pr_labels(self, update=False):
-        raise NotImplementedError('Getting labels is not implemented for the local git provider')
+        raise NotImplementedError(
+            'Getting labels is not implemented for the local git provider'
+        )
@ -6,7 +6,7 @@ from dynaconf import Dynaconf
from starlette_context import context

from utils.pr_agent.config_loader import get_settings
-from utils.pr_agent.git_providers import (get_git_provider_with_context)
+from utils.pr_agent.git_providers import get_git_provider_with_context
from utils.pr_agent.log import get_logger


@ -20,7 +20,9 @@ def apply_repo_settings(pr_url):
        except Exception:
            repo_settings = None
            pass
-    if repo_settings is None: # None is different from "", which is a valid value
+    if (
+        repo_settings is None
+    ): # None is different from "", which is a valid value
        repo_settings = git_provider.get_repo_settings()
        try:
            context["repo_settings"] = repo_settings
@ -36,15 +38,25 @@ def apply_repo_settings(pr_url):
                os.write(fd, repo_settings)
            new_settings = Dynaconf(settings_files=[repo_settings_file])
            for section, contents in new_settings.as_dict().items():
-                section_dict = copy.deepcopy(get_settings().as_dict().get(section, {}))
+                section_dict = copy.deepcopy(
+                    get_settings().as_dict().get(section, {})
+                )
                for key, value in contents.items():
                    section_dict[key] = value
                get_settings().unset(section)
                get_settings().set(section, section_dict, merge=False)
-            get_logger().info(f"Applying repo settings:\n{new_settings.as_dict()}")
+            get_logger().info(
+                f"Applying repo settings:\n{new_settings.as_dict()}"
+            )
        except Exception as e:
-            get_logger().warning(f"Failed to apply repo {category} settings, error: {str(e)}")
-            error_local = {'error': str(e), 'settings': repo_settings, 'category': category}
+            get_logger().warning(
+                f"Failed to apply repo {category} settings, error: {str(e)}"
+            )
+            error_local = {
+                'error': str(e),
+                'settings': repo_settings,
+                'category': category,
+            }

        if error_local:
            handle_configurations_errors([error_local], git_provider)
@ -55,7 +67,10 @@ def apply_repo_settings(pr_url):
            try:
                os.remove(repo_settings_file)
            except Exception as e:
-                get_logger().error(f"Failed to remove temporary settings file {repo_settings_file}", e)
+                get_logger().error(
+                    f"Failed to remove temporary settings file {repo_settings_file}",
+                    e,
+                )

    # enable switching models with a short definition
    if get_settings().config.model.lower() == 'claude-3-5-sonnet':
@ -79,13 +94,18 @@ def handle_configurations_errors(config_errors, git_provider):
                body += f"\n\n<details><summary>配置内容:</summary>\n\n```toml\n{configuration_file_content}\n```\n\n</details>"
            else:
                body += f"\n\n**配置内容:**\n\n```toml\n{configuration_file_content}\n```\n\n"
-            get_logger().warning(f"Sending a 'configuration error' comment to the PR", artifact={'body': body})
+            get_logger().warning(
+                f"Sending a 'configuration error' comment to the PR",
+                artifact={'body': body},
+            )
            # git_provider.publish_comment(body)
            if hasattr(git_provider, 'publish_persistent_comment'):
-                git_provider.publish_persistent_comment(body,
-                                                        initial_header=header,
-                                                        update_header=False,
-                                                        final_update_message=False)
+                git_provider.publish_persistent_comment(
+                    body,
+                    initial_header=header,
+                    update_header=False,
+                    final_update_message=False,
+                )
            else:
                git_provider.publish_comment(body)
        except Exception as e:
@ -1,10 +1,9 @@
from utils.pr_agent.config_loader import get_settings
-from utils.pr_agent.identity_providers.default_identity_provider import \
-    DefaultIdentityProvider
+from utils.pr_agent.identity_providers.default_identity_provider import (
+    DefaultIdentityProvider,
+)

-_IDENTITY_PROVIDERS = {
-    'default': DefaultIdentityProvider
-}
+_IDENTITY_PROVIDERS = {'default': DefaultIdentityProvider}


def get_identity_provider():
@ -1,5 +1,7 @@
-from utils.pr_agent.identity_providers.identity_provider import (Eligibility,
-                                                                 IdentityProvider)
+from utils.pr_agent.identity_providers.identity_provider import (
+    Eligibility,
+    IdentityProvider,
+)


class DefaultIdentityProvider(IdentityProvider):
@ -30,7 +30,9 @@ def setup_logger(level: str = "INFO", fmt: LoggingFormat = LoggingFormat.CONSOLE
    if type(level) is not int:
        level = logging.INFO

-    if fmt == LoggingFormat.JSON and os.getenv("LOG_SANE", "0").lower() == "0": # better debugging github_app
+    if (
+        fmt == LoggingFormat.JSON and os.getenv("LOG_SANE", "0").lower() == "0"
+    ): # better debugging github_app
        logger.remove(None)
        logger.add(
            sys.stdout,
@ -40,7 +42,7 @@ def setup_logger(level: str = "INFO", fmt: LoggingFormat = LoggingFormat.CONSOLE
            colorize=False,
            serialize=True,
        )
    elif fmt == LoggingFormat.CONSOLE: # does not print the 'extra' fields
        logger.remove(None)
        logger.add(sys.stdout, level=level, colorize=True, filter=inv_analytics_filter)

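For reference, a minimal standalone sketch of the loguru pattern shown in the setup_logger hunks above, where a JSON-serializing sink replaces the default one; the sink, level and bound fields here are illustrative only, not values from this configuration:

import sys
from loguru import logger

# Replace the default sink with one that serializes each record as a JSON line.
logger.remove()
logger.add(sys.stdout, level="INFO", serialize=True)

# Fields attached with bind() appear under "extra" in the serialized record.
logger.bind(component="example").info("structured log line")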
@ -8,10 +8,14 @@ def get_secret_provider():
    provider_id = get_settings().config.secret_provider
    if provider_id == 'google_cloud_storage':
        try:
-            from utils.pr_agent.secret_providers.google_cloud_storage_secret_provider import \
-                GoogleCloudStorageSecretProvider
+            from utils.pr_agent.secret_providers.google_cloud_storage_secret_provider import (
+                GoogleCloudStorageSecretProvider,
+            )
+
            return GoogleCloudStorageSecretProvider()
        except Exception as e:
-            raise ValueError(f"Failed to initialize google_cloud_storage secret provider {provider_id}") from e
+            raise ValueError(
+                f"Failed to initialize google_cloud_storage secret provider {provider_id}"
+            ) from e
    else:
        raise ValueError("Unknown SECRET_PROVIDER")
@ -9,12 +9,15 @@ from utils.pr_agent.secret_providers.secret_provider import SecretProvider
class GoogleCloudStorageSecretProvider(SecretProvider):
    def __init__(self):
        try:
-            self.client = storage.Client.from_service_account_info(ujson.loads(get_settings().google_cloud_storage.
-                                                                                service_account))
+            self.client = storage.Client.from_service_account_info(
+                ujson.loads(get_settings().google_cloud_storage.service_account)
+            )
            self.bucket_name = get_settings().google_cloud_storage.bucket_name
            self.bucket = self.client.bucket(self.bucket_name)
        except Exception as e:
-            get_logger().error(f"Failed to initialize Google Cloud Storage Secret Provider: {e}")
+            get_logger().error(
+                f"Failed to initialize Google Cloud Storage Secret Provider: {e}"
+            )
            raise e

    def get_secret(self, secret_name: str) -> str:
@ -22,7 +25,9 @@ class GoogleCloudStorageSecretProvider(SecretProvider):
            blob = self.bucket.blob(secret_name)
            return blob.download_as_string()
        except Exception as e:
-            get_logger().warning(f"Failed to get secret {secret_name} from Google Cloud Storage: {e}")
+            get_logger().warning(
+                f"Failed to get secret {secret_name} from Google Cloud Storage: {e}"
+            )
            return ""

    def store_secret(self, secret_name: str, secret_value: str):
@ -30,5 +35,7 @@ class GoogleCloudStorageSecretProvider(SecretProvider):
            blob = self.bucket.blob(secret_name)
            blob.upload_from_string(secret_value)
        except Exception as e:
-            get_logger().error(f"Failed to store secret {secret_name} in Google Cloud Storage: {e}")
+            get_logger().error(
+                f"Failed to store secret {secret_name} in Google Cloud Storage: {e}"
+            )
            raise e
@ -2,7 +2,6 @@ from abc import ABC, abstractmethod


class SecretProvider(ABC):
-
    @abstractmethod
    def get_secret(self, secret_name: str) -> str:
        pass
@ -33,6 +33,7 @@ azure_devops_server = get_settings().get("azure_devops_server")
WEBHOOK_USERNAME = azure_devops_server.get("webhook_username")
WEBHOOK_PASSWORD = azure_devops_server.get("webhook_password")

+
def handle_request(
    background_tasks: BackgroundTasks, url: str, body: str, log_context: dict
):
@ -52,20 +53,27 @@ def handle_request(
# currently only basic auth is supported with azure webhooks
# for this reason, https must be enabled to ensure the credentials are not sent in clear text
def authorize(credentials: HTTPBasicCredentials = Depends(security)):
    is_user_ok = secrets.compare_digest(credentials.username, WEBHOOK_USERNAME)
    is_pass_ok = secrets.compare_digest(credentials.password, WEBHOOK_PASSWORD)
    if not (is_user_ok and is_pass_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail='Incorrect username or password.',
            headers={'WWW-Authenticate': 'Basic'},
        )


-async def _perform_commands_azure(commands_conf: str, agent: PRAgent, api_url: str, log_context: dict):
+async def _perform_commands_azure(
+    commands_conf: str, agent: PRAgent, api_url: str, log_context: dict
+):
    apply_repo_settings(api_url)
-    if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
-        get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}", **log_context)
+    if (
+        commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
+    ): # auto commands for PR, and auto feedback is disabled
+        get_logger().info(
+            f"Auto feedback is disabled, skipping auto commands for PR {api_url=}",
+            **log_context,
+        )
        return
    commands = get_settings().get(f"azure_devops_server.{commands_conf}")
    get_settings().set("config.is_auto_command", True)
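For context, a minimal sketch of how a client could exercise a Basic-auth-protected webhook endpoint like the one guarded by the authorize() check above; the endpoint URL, credentials and payload are placeholders for illustration, not values from this codebase:

import requests

# Hypothetical endpoint and credentials, for illustration only.
resp = requests.post(
    "https://pr-agent.example.com/webhook",
    json={"eventType": "git.pullrequest.created"},
    auth=("webhook_user", "webhook_password"),  # sent as HTTP Basic auth
    timeout=30,
)
print(resp.status_code)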
@ -92,22 +100,38 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
    actions = []
    if data["eventType"] == "git.pullrequest.created":
        # API V1 (latest)
-        pr_url = unquote(data["resource"]["_links"]["web"]["href"].replace("_apis/git/repositories", "_git"))
+        pr_url = unquote(
+            data["resource"]["_links"]["web"]["href"].replace(
+                "_apis/git/repositories", "_git"
+            )
+        )
        log_context["event"] = data["eventType"]
        log_context["api_url"] = pr_url
        await _perform_commands_azure("pr_commands", PRAgent(), pr_url, log_context)
        return
-    elif data["eventType"] == "ms.vss-code.git-pullrequest-comment-event" and "content" in data["resource"]["comment"]:
+    elif (
+        data["eventType"] == "ms.vss-code.git-pullrequest-comment-event"
+        and "content" in data["resource"]["comment"]
+    ):
        if available_commands_rgx.match(data["resource"]["comment"]["content"]):
-            if(data["resourceVersion"] == "2.0"):
+            if data["resourceVersion"] == "2.0":
                repo = data["resource"]["pullRequest"]["repository"]["webUrl"]
-                pr_url = unquote(f'{repo}/pullrequest/{data["resource"]["pullRequest"]["pullRequestId"]}')
+                pr_url = unquote(
+                    f'{repo}/pullrequest/{data["resource"]["pullRequest"]["pullRequestId"]}'
+                )
                actions = [data["resource"]["comment"]["content"]]
            else:
                # API V1 not supported as it does not contain the PR URL
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content=json.dumps({"message": "version 1.0 webhook for Azure Devops PR comment is not supported. please upgrade to version 2.0"})),
+                return (
+                    JSONResponse(
+                        status_code=status.HTTP_400_BAD_REQUEST,
+                        content=json.dumps(
+                            {
+                                "message": "version 1.0 webhook for Azure Devops PR comment is not supported. please upgrade to version 2.0"
+                            }
+                        ),
+                    ),
+                )
    else:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
@ -132,17 +156,21 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
            content=json.dumps({"message": "Internal server error"}),
        )
    return JSONResponse(
-        status_code=status.HTTP_202_ACCEPTED, content=jsonable_encoder({"message": "webhook triggered successfully"})
+        status_code=status.HTTP_202_ACCEPTED,
+        content=jsonable_encoder({"message": "webhook triggered successfully"}),
    )

+
@router.get("/")
async def root():
    return {"status": "ok"}

+
def start():
    app = FastAPI(middleware=[Middleware(RawContextMiddleware)])
    app.include_router(router)
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "3000")))

+
if __name__ == "__main__":
    start()
@ -27,7 +27,9 @@ from utils.pr_agent.secret_providers import get_secret_provider

setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
router = APIRouter()
-secret_provider = get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
+secret_provider = (
+    get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
+)


async def get_bearer_token(shared_secret: str, client_key: str):
@ -44,12 +46,12 @@ async def get_bearer_token(shared_secret: str, client_key: str):
            "exp": now + 240,
            "qsh": qsh,
            "sub": client_key,
        }
        token = jwt.encode(payload, shared_secret, algorithm="HS256")
        payload = 'grant_type=urn%3Abitbucket%3Aoauth2%3Ajwt'
        headers = {
            'Authorization': f'JWT {token}',
-            'Content-Type': 'application/x-www-form-urlencoded'
+            'Content-Type': 'application/x-www-form-urlencoded',
        }
        response = requests.request("POST", url, headers=headers, data=payload)
        bearer_token = response.json()["access_token"]
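For reference, a minimal PyJWT sketch of the HS256 signing step used in get_bearer_token above; the shared secret and claims below are placeholders, not real credentials:

import time
import jwt  # PyJWT

shared_secret = "placeholder-shared-secret"  # illustrative only
now = int(time.time())
payload = {"iss": "placeholder-client-key", "iat": now, "exp": now + 240}
token = jwt.encode(payload, shared_secret, algorithm="HS256")
print(token[:20], "...")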
@ -58,6 +60,7 @@ async def get_bearer_token(shared_secret: str, client_key: str):
        get_logger().error(f"Failed to get bearer token: {e}")
        raise e

+
@router.get("/")
async def handle_manifest(request: Request, response: Response):
    cur_dir = os.path.dirname(os.path.abspath(__file__))
@ -66,7 +69,9 @@ async def handle_manifest(request: Request, response: Response):
        manifest = manifest.replace("app_key", get_settings().bitbucket.app_key)
        manifest = manifest.replace("base_url", get_settings().bitbucket.base_url)
    except:
-        get_logger().error("Failed to replace api_key in Bitbucket manifest, trying to continue")
+        get_logger().error(
+            "Failed to replace api_key in Bitbucket manifest, trying to continue"
+        )
    manifest_obj = json.loads(manifest)
    return JSONResponse(manifest_obj)

@ -83,10 +88,16 @@ def _get_username(data):
    return ""


-async def _perform_commands_bitbucket(commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict):
+async def _perform_commands_bitbucket(
+    commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict
+):
    apply_repo_settings(api_url)
-    if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
-        get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}")
+    if (
+        commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
+    ): # auto commands for PR, and auto feedback is disabled
+        get_logger().info(
+            f"Auto feedback is disabled, skipping auto commands for PR {api_url=}"
+        )
        return
    if data.get("event", "") == "pullrequest:created":
        if not should_process_pr_logic(data):
@ -132,7 +143,9 @@ def should_process_pr_logic(data) -> bool:
        ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
        if ignore_pr_users and sender:
            if sender in ignore_pr_users:
-                get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting")
+                get_logger().info(
+                    f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting"
+                )
                return False

        # logic to ignore PRs with specific titles
@ -140,20 +153,34 @@ def should_process_pr_logic(data) -> bool:
        ignore_pr_title_re = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
        if not isinstance(ignore_pr_title_re, list):
            ignore_pr_title_re = [ignore_pr_title_re]
-        if ignore_pr_title_re and any(re.search(regex, title) for regex in ignore_pr_title_re):
-            get_logger().info(f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting")
+        if ignore_pr_title_re and any(
+            re.search(regex, title) for regex in ignore_pr_title_re
+        ):
+            get_logger().info(
+                f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting"
+            )
            return False

-        ignore_pr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [])
-        ignore_pr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", [])
-        if (ignore_pr_source_branches or ignore_pr_target_branches):
-            if any(re.search(regex, source_branch) for regex in ignore_pr_source_branches):
+        ignore_pr_source_branches = get_settings().get(
+            "CONFIG.IGNORE_PR_SOURCE_BRANCHES", []
+        )
+        ignore_pr_target_branches = get_settings().get(
+            "CONFIG.IGNORE_PR_TARGET_BRANCHES", []
+        )
+        if ignore_pr_source_branches or ignore_pr_target_branches:
+            if any(
+                re.search(regex, source_branch) for regex in ignore_pr_source_branches
+            ):
                get_logger().info(
-                    f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings")
+                    f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings"
+                )
                return False
-            if any(re.search(regex, target_branch) for regex in ignore_pr_target_branches):
+            if any(
+                re.search(regex, target_branch) for regex in ignore_pr_target_branches
+            ):
                get_logger().info(
-                    f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings")
+                    f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings"
+                )
                return False
    except Exception as e:
        get_logger().error(f"Failed 'should_process_pr_logic': {e}")
@ -195,7 +222,9 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
            client_key = claims["iss"]
            secrets = json.loads(secret_provider.get_secret(client_key))
            shared_secret = secrets["shared_secret"]
-            jwt.decode(input_jwt, shared_secret, audience=client_key, algorithms=["HS256"])
+            jwt.decode(
+                input_jwt, shared_secret, audience=client_key, algorithms=["HS256"]
+            )
            bearer_token = await get_bearer_token(shared_secret, client_key)
            context['bitbucket_bearer_token'] = bearer_token
            context["settings"] = copy.deepcopy(global_settings)
@ -208,28 +237,41 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
            if pr_url:
                with get_logger().contextualize(**log_context):
                    apply_repo_settings(pr_url)
-                    if get_identity_provider().verify_eligibility("bitbucket",
-                                                                  sender_id, pr_url) is not Eligibility.NOT_ELIGIBLE:
+                    if (
+                        get_identity_provider().verify_eligibility(
+                            "bitbucket", sender_id, pr_url
+                        )
+                        is not Eligibility.NOT_ELIGIBLE
+                    ):
                        if get_settings().get("bitbucket_app.pr_commands"):
-                            await _perform_commands_bitbucket("pr_commands", PRAgent(), pr_url, log_context, data)
+                            await _perform_commands_bitbucket(
+                                "pr_commands", PRAgent(), pr_url, log_context, data
+                            )
        elif event == "pullrequest:comment_created":
            pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
            log_context["api_url"] = pr_url
            log_context["event"] = "comment"
            comment_body = data["data"]["comment"]["content"]["raw"]
            with get_logger().contextualize(**log_context):
-                if get_identity_provider().verify_eligibility("bitbucket",
-                                                              sender_id, pr_url) is not Eligibility.NOT_ELIGIBLE:
+                if (
+                    get_identity_provider().verify_eligibility(
+                        "bitbucket", sender_id, pr_url
+                    )
+                    is not Eligibility.NOT_ELIGIBLE
+                ):
                    await agent.handle_request(pr_url, comment_body)
        except Exception as e:
            get_logger().error(f"Failed to handle webhook: {e}")

    background_tasks.add_task(inner)
    return "OK"

+
@router.get("/webhook")
async def handle_github_webhooks(request: Request, response: Response):
    return "Webhook server online!"

+
@router.post("/installed")
async def handle_installed_webhooks(request: Request, response: Response):
    try:
@ -240,15 +282,13 @@ async def handle_installed_webhooks(request: Request, response: Response):
        shared_secret = data["sharedSecret"]
        client_key = data["clientKey"]
        username = data["principal"]["username"]
-        secrets = {
-            "shared_secret": shared_secret,
-            "client_key": client_key
-        }
+        secrets = {"shared_secret": shared_secret, "client_key": client_key}
        secret_provider.store_secret(username, json.dumps(secrets))
    except Exception as e:
        get_logger().error(f"Failed to register user: {e}")
        return JSONResponse({"error": "Unable to register user"}, status_code=500)


@router.post("/uninstalled")
async def handle_uninstalled_webhooks(request: Request, response: Response):
    get_logger().info("handle_uninstalled_webhooks")
@ -40,10 +40,12 @@ def handle_request(

    background_tasks.add_task(inner)

+
@router.post("/")
async def redirect_to_webhook():
    return RedirectResponse(url="/webhook")

+
@router.post("/webhook")
async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
    log_context = {"server_type": "bitbucket_server"}
@ -55,7 +57,8 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):
    body_bytes = await request.body()
    if body_bytes.decode('utf-8') == '{"test": true}':
        return JSONResponse(
-            status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "connection test successful"})
+            status_code=status.HTTP_200_OK,
+            content=jsonable_encoder({"message": "connection test successful"}),
        )
    signature_header = request.headers.get("x-hub-signature", None)
    verify_signature(body_bytes, webhook_secret, signature_header)
@ -73,11 +76,18 @@ async def handle_webhook(background_tasks: BackgroundTasks, request: Request):

        if data["eventKey"] == "pr:opened":
            apply_repo_settings(pr_url)
-            if get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
-                get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {pr_url}", **log_context)
+            if (
+                get_settings().config.disable_auto_feedback
+            ): # auto commands for PR, and auto feedback is disabled
+                get_logger().info(
+                    f"Auto feedback is disabled, skipping auto commands for PR {pr_url}",
+                    **log_context,
+                )
                return
            get_settings().set("config.is_auto_command", True)
-            commands_to_run.extend(_get_commands_list_from_settings('BITBUCKET_SERVER.PR_COMMANDS'))
+            commands_to_run.extend(
+                _get_commands_list_from_settings('BITBUCKET_SERVER.PR_COMMANDS')
+            )
        elif data["eventKey"] == "pr:comment:added":
            commands_to_run.append(data["comment"]["text"])
        else:
@ -116,6 +126,7 @@ async def _run_commands_sequentially(commands: List[str], url: str, log_context:
        except Exception as e:
            get_logger().error(f"Failed to handle command: {command} , error: {e}")

+
def _process_command(command: str, url) -> str:
    # don't think we need this
    apply_repo_settings(url)
@ -142,11 +153,13 @@ def _to_list(command_string: str) -> list:
        raise ValueError(f"Invalid command string: {e}")


-def _get_commands_list_from_settings(setting_key:str ) -> list:
+def _get_commands_list_from_settings(setting_key: str) -> list:
    try:
        return get_settings().get(setting_key, [])
    except ValueError as e:
-        get_logger().error(f"Failed to get commands list from settings {setting_key}: {e}")
+        get_logger().error(
+            f"Failed to get commands list from settings {setting_key}: {e}"
+        )


@router.get("/")
@ -40,12 +40,10 @@ async def handle_gerrit_request(action: Action, item: Item):
    if action == Action.ask:
        if not item.msg:
            return HTTPException(
-                status_code=400,
-                detail="msg is required for ask command"
+                status_code=400, detail="msg is required for ask command"
            )
        await PRAgent().handle_request(
-            f"{item.project}:{item.refspec}",
-            f"/{item.msg.strip()}"
+            f"{item.project}:{item.refspec}", f"/{item.msg.strip()}"
        )


@ -26,7 +26,12 @@ def get_setting_or_env(key: str, default: Union[str, bool] = None) -> Union[str,
    try:
        value = get_settings().get(key, default)
    except AttributeError: # TBD still need to debug why this happens on GitHub Actions
-        value = os.getenv(key, None) or os.getenv(key.upper(), None) or os.getenv(key.lower(), None) or default
+        value = (
+            os.getenv(key, None)
+            or os.getenv(key.upper(), None)
+            or os.getenv(key.lower(), None)
+            or default
+        )
    return value


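For context, a standalone sketch of the settings-then-environment fallback expressed in the hunk above (standard library only; the key name and default are illustrative):

import os

def env_fallback(key: str, default=None):
    # Try the key as given, then upper- and lower-cased variants, then the default.
    return (
        os.getenv(key)
        or os.getenv(key.upper())
        or os.getenv(key.lower())
        or default
    )

print(env_fallback("github_action.auto_review", default="not set"))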
@ -76,16 +81,24 @@ async def run_action():
            pr_url = event_payload.get("pull_request", {}).get("html_url")
            if pr_url:
                apply_repo_settings(pr_url)
-                get_logger().info(f"enable_custom_labels: {get_settings().config.enable_custom_labels}")
+                get_logger().info(
+                    f"enable_custom_labels: {get_settings().config.enable_custom_labels}"
+                )
        except Exception as e:
            get_logger().info(f"github action: failed to apply repo settings: {e}")

    # Handle pull request opened event
-    if GITHUB_EVENT_NAME == "pull_request" or GITHUB_EVENT_NAME == "pull_request_target":
+    if (
+        GITHUB_EVENT_NAME == "pull_request"
+        or GITHUB_EVENT_NAME == "pull_request_target"
+    ):
        action = event_payload.get("action")

        # Retrieve the list of actions from the configuration
-        pr_actions = get_settings().get("GITHUB_ACTION_CONFIG.PR_ACTIONS", ["opened", "reopened", "ready_for_review", "review_requested"])
+        pr_actions = get_settings().get(
+            "GITHUB_ACTION_CONFIG.PR_ACTIONS",
+            ["opened", "reopened", "ready_for_review", "review_requested"],
+        )

        if action in pr_actions:
            pr_url = event_payload.get("pull_request", {}).get("url")
@ -93,18 +106,30 @@ async def run_action():
            # legacy - supporting both GITHUB_ACTION and GITHUB_ACTION_CONFIG
            auto_review = get_setting_or_env("GITHUB_ACTION.AUTO_REVIEW", None)
            if auto_review is None:
-                auto_review = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_REVIEW", None)
+                auto_review = get_setting_or_env(
+                    "GITHUB_ACTION_CONFIG.AUTO_REVIEW", None
+                )
            auto_describe = get_setting_or_env("GITHUB_ACTION.AUTO_DESCRIBE", None)
            if auto_describe is None:
-                auto_describe = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_DESCRIBE", None)
+                auto_describe = get_setting_or_env(
+                    "GITHUB_ACTION_CONFIG.AUTO_DESCRIBE", None
+                )
            auto_improve = get_setting_or_env("GITHUB_ACTION.AUTO_IMPROVE", None)
            if auto_improve is None:
-                auto_improve = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_IMPROVE", None)
+                auto_improve = get_setting_or_env(
+                    "GITHUB_ACTION_CONFIG.AUTO_IMPROVE", None
+                )

            # Set the configuration for auto actions
-            get_settings().config.is_auto_command = True # Set the flag to indicate that the command is auto
-            get_settings().pr_description.final_update_message = False # No final update message when auto_describe is enabled
-            get_logger().info(f"Running auto actions: auto_describe={auto_describe}, auto_review={auto_review}, auto_improve={auto_improve}")
+            get_settings().config.is_auto_command = (
+                True # Set the flag to indicate that the command is auto
+            )
+            get_settings().pr_description.final_update_message = (
+                False # No final update message when auto_describe is enabled
+            )
+            get_logger().info(
+                f"Running auto actions: auto_describe={auto_describe}, auto_review={auto_review}, auto_improve={auto_improve}"
+            )

            # invoke by default all three tools
            if auto_describe is None or is_true(auto_describe):
@ -117,7 +142,10 @@ async def run_action():
            get_logger().info(f"Skipping action: {action}")

    # Handle issue comment event
-    elif GITHUB_EVENT_NAME == "issue_comment" or GITHUB_EVENT_NAME == "pull_request_review_comment":
+    elif (
+        GITHUB_EVENT_NAME == "issue_comment"
+        or GITHUB_EVENT_NAME == "pull_request_review_comment"
+    ):
        action = event_payload.get("action")
        if action in ["created", "edited"]:
            comment_body = event_payload.get("comment", {}).get("body")
@ -133,9 +161,15 @@ async def run_action():
            disable_eyes = False
            # check if issue is pull request
            if event_payload.get("issue", {}).get("pull_request"):
-                url = event_payload.get("issue", {}).get("pull_request", {}).get("url")
+                url = (
+                    event_payload.get("issue", {})
+                    .get("pull_request", {})
+                    .get("url")
+                )
                is_pr = True
-            elif event_payload.get("comment", {}).get("pull_request_url"): # for 'pull_request_review_comment
+            elif event_payload.get("comment", {}).get(
+                "pull_request_url"
+            ): # for 'pull_request_review_comment
                url = event_payload.get("comment", {}).get("pull_request_url")
                is_pr = True
                disable_eyes = True
@ -148,9 +182,11 @@ async def run_action():
            provider = get_git_provider()(pr_url=url)
            if is_pr:
                await PRAgent().handle_request(
-                    url, body, notify=lambda: provider.add_eyes_reaction(
+                    url,
+                    body,
+                    notify=lambda: provider.add_eyes_reaction(
                        comment_id, disable_eyes=disable_eyes
-                    )
+                    ),
                )
            else:
                await PRAgent().handle_request(url, body)
@ -15,8 +15,7 @@ from starlette_context.middleware import RawContextMiddleware
from utils.pr_agent.agent.pr_agent import PRAgent
from utils.pr_agent.algo.utils import update_settings_from_args
from utils.pr_agent.config_loader import get_settings, global_settings
-from utils.pr_agent.git_providers import (get_git_provider,
-                                          get_git_provider_with_context)
+from utils.pr_agent.git_providers import get_git_provider, get_git_provider_with_context
from utils.pr_agent.git_providers.utils import apply_repo_settings
from utils.pr_agent.identity_providers import get_identity_provider
from utils.pr_agent.identity_providers.identity_provider import Eligibility
@ -35,7 +34,9 @@ router = APIRouter()


@router.post("/api/v1/github_webhooks")
-async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Request, response: Response):
+async def handle_github_webhooks(
+    background_tasks: BackgroundTasks, request: Request, response: Response
+):
    """
    Receives and processes incoming GitHub webhook requests.
    Verifies the request signature, parses the request body, and passes it to the handle_request function for further
@ -49,7 +50,9 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
    context["installation_id"] = installation_id
    context["settings"] = copy.deepcopy(global_settings)
    context["git_provider"] = {}
-    background_tasks.add_task(handle_request, body, event=request.headers.get("X-GitHub-Event", None))
+    background_tasks.add_task(
+        handle_request, body, event=request.headers.get("X-GitHub-Event", None)
+    )
    return {}


@ -73,35 +76,61 @@ async def get_body(request):
    return body


-_duplicate_push_triggers = DefaultDictWithTimeout(ttl=get_settings().github_app.push_trigger_pending_tasks_ttl)
-_pending_task_duplicate_push_conditions = DefaultDictWithTimeout(asyncio.locks.Condition, ttl=get_settings().github_app.push_trigger_pending_tasks_ttl)
+_duplicate_push_triggers = DefaultDictWithTimeout(
+    ttl=get_settings().github_app.push_trigger_pending_tasks_ttl
+)
+_pending_task_duplicate_push_conditions = DefaultDictWithTimeout(
+    asyncio.locks.Condition,
+    ttl=get_settings().github_app.push_trigger_pending_tasks_ttl,
+)

+
-async def handle_comments_on_pr(body: Dict[str, Any],
-                                event: str,
-                                sender: str,
-                                sender_id: str,
-                                action: str,
-                                log_context: Dict[str, Any],
-                                agent: PRAgent):
+async def handle_comments_on_pr(
+    body: Dict[str, Any],
+    event: str,
+    sender: str,
+    sender_id: str,
+    action: str,
+    log_context: Dict[str, Any],
+    agent: PRAgent,
+):
    if "comment" not in body:
        return {}
    comment_body = body.get("comment", {}).get("body")
-    if comment_body and isinstance(comment_body, str) and not comment_body.lstrip().startswith("/"):
+    if (
+        comment_body
+        and isinstance(comment_body, str)
+        and not comment_body.lstrip().startswith("/")
+    ):
        if '/ask' in comment_body and comment_body.strip().startswith('> ![image]'):
            comment_body_split = comment_body.split('/ask')
-            comment_body = '/ask' + comment_body_split[1] +' \n' +comment_body_split[0].strip().lstrip('>')
-            get_logger().info(f"Reformatting comment_body so command is at the beginning: {comment_body}")
+            comment_body = (
+                '/ask'
+                + comment_body_split[1]
+                + ' \n'
+                + comment_body_split[0].strip().lstrip('>')
+            )
+            get_logger().info(
+                f"Reformatting comment_body so command is at the beginning: {comment_body}"
+            )
        else:
            get_logger().info("Ignoring comment not starting with /")
            return {}
    disable_eyes = False
-    if "issue" in body and "pull_request" in body["issue"] and "url" in body["issue"]["pull_request"]:
+    if (
+        "issue" in body
+        and "pull_request" in body["issue"]
+        and "url" in body["issue"]["pull_request"]
+    ):
        api_url = body["issue"]["pull_request"]["url"]
    elif "comment" in body and "pull_request_url" in body["comment"]:
        api_url = body["comment"]["pull_request_url"]
    try:
-        if ('/ask' in comment_body and
-                'subject_type' in body["comment"] and body["comment"]["subject_type"] == "line"):
+        if (
+            '/ask' in comment_body
+            and 'subject_type' in body["comment"]
+            and body["comment"]["subject_type"] == "line"
+        ):
            # comment on a code line in the "files changed" tab
|
||||||
comment_body = handle_line_comments(body, comment_body)
|
comment_body = handle_line_comments(body, comment_body)
|
||||||
disable_eyes = True
|
disable_eyes = True
|
||||||
@@ -113,46 +142,75 @@ async def handle_comments_on_pr(body: Dict[str, Any],
|
|||||||
comment_id = body.get("comment", {}).get("id")
|
comment_id = body.get("comment", {}).get("id")
|
||||||
provider = get_git_provider_with_context(pr_url=api_url)
|
provider = get_git_provider_with_context(pr_url=api_url)
|
||||||
with get_logger().contextualize(**log_context):
|
with get_logger().contextualize(**log_context):
|
||||||
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
|
if (
|
||||||
get_logger().info(f"Processing comment on PR {api_url=}, comment_body={comment_body}")
|
get_identity_provider().verify_eligibility("github", sender_id, api_url)
|
||||||
await agent.handle_request(api_url, comment_body,
|
is not Eligibility.NOT_ELIGIBLE
|
||||||
notify=lambda: provider.add_eyes_reaction(comment_id, disable_eyes=disable_eyes))
|
):
|
||||||
|
get_logger().info(
|
||||||
|
f"Processing comment on PR {api_url=}, comment_body={comment_body}"
|
||||||
|
)
|
||||||
|
await agent.handle_request(
|
||||||
|
api_url,
|
||||||
|
comment_body,
|
||||||
|
notify=lambda: provider.add_eyes_reaction(
|
||||||
|
comment_id, disable_eyes=disable_eyes
|
||||||
|
),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
get_logger().info(f"User {sender=} is not eligible to process comment on PR {api_url=}")
|
get_logger().info(
|
||||||
|
f"User {sender=} is not eligible to process comment on PR {api_url=}"
|
||||||
|
)
|
||||||
|
|
||||||
async def handle_new_pr_opened(body: Dict[str, Any],
|
|
||||||
event: str,
|
async def handle_new_pr_opened(
|
||||||
sender: str,
|
body: Dict[str, Any],
|
||||||
sender_id: str,
|
event: str,
|
||||||
action: str,
|
sender: str,
|
||||||
log_context: Dict[str, Any],
|
sender_id: str,
|
||||||
agent: PRAgent):
|
action: str,
|
||||||
|
log_context: Dict[str, Any],
|
||||||
|
agent: PRAgent,
|
||||||
|
):
|
||||||
title = body.get("pull_request", {}).get("title", "")
|
title = body.get("pull_request", {}).get("title", "")
|
||||||
|
|
||||||
pull_request, api_url = _check_pull_request_event(action, body, log_context)
|
pull_request, api_url = _check_pull_request_event(action, body, log_context)
|
||||||
if not (pull_request and api_url):
|
if not (pull_request and api_url):
|
||||||
get_logger().info(f"Invalid PR event: {action=} {api_url=}")
|
get_logger().info(f"Invalid PR event: {action=} {api_url=}")
|
||||||
return {}
|
return {}
|
||||||
if action in get_settings().github_app.handle_pr_actions: # ['opened', 'reopened', 'ready_for_review']
|
if (
|
||||||
|
action in get_settings().github_app.handle_pr_actions
|
||||||
|
): # ['opened', 'reopened', 'ready_for_review']
|
||||||
# logic to ignore PRs with specific titles (e.g. "[Auto] ...")
|
# logic to ignore PRs with specific titles (e.g. "[Auto] ...")
|
||||||
apply_repo_settings(api_url)
|
apply_repo_settings(api_url)
|
||||||
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
|
if (
|
||||||
await _perform_auto_commands_github("pr_commands", agent, body, api_url, log_context)
|
get_identity_provider().verify_eligibility("github", sender_id, api_url)
|
||||||
|
is not Eligibility.NOT_ELIGIBLE
|
||||||
|
):
|
||||||
|
await _perform_auto_commands_github(
|
||||||
|
"pr_commands", agent, body, api_url, log_context
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
get_logger().info(f"User {sender=} is not eligible to process PR {api_url=}")
|
get_logger().info(
|
||||||
|
f"User {sender=} is not eligible to process PR {api_url=}"
|
||||||
|
)
|
||||||
|
|
||||||
async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
|
|
||||||
event: str,
|
async def handle_push_trigger_for_new_commits(
|
||||||
sender: str,
|
body: Dict[str, Any],
|
||||||
sender_id: str,
|
event: str,
|
||||||
action: str,
|
sender: str,
|
||||||
log_context: Dict[str, Any],
|
sender_id: str,
|
||||||
agent: PRAgent):
|
action: str,
|
||||||
|
log_context: Dict[str, Any],
|
||||||
|
agent: PRAgent,
|
||||||
|
):
|
||||||
pull_request, api_url = _check_pull_request_event(action, body, log_context)
|
pull_request, api_url = _check_pull_request_event(action, body, log_context)
|
||||||
if not (pull_request and api_url):
|
if not (pull_request and api_url):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
apply_repo_settings(api_url) # we need to apply the repo settings to get the correct settings for the PR. This is quite expensive - a call to the git provider is made for each PR event.
|
apply_repo_settings(
|
||||||
|
api_url
|
||||||
|
) # we need to apply the repo settings to get the correct settings for the PR. This is quite expensive - a call to the git provider is made for each PR event.
|
||||||
if not get_settings().github_app.handle_push_trigger:
|
if not get_settings().github_app.handle_push_trigger:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@@ -162,7 +220,10 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
    merge_commit_sha = pull_request.get("merge_commit_sha")
    if before_sha == after_sha:
        return {}
-    if get_settings().github_app.push_trigger_ignore_merge_commits and after_sha == merge_commit_sha:
+    if (
+        get_settings().github_app.push_trigger_ignore_merge_commits
+        and after_sha == merge_commit_sha
+    ):
        return {}

    # Prevent triggering multiple times for subsequent push triggers when one is enough:
@@ -172,7 +233,9 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
    # more commits may have been pushed that led to the subsequent events,
    # so we keep just one waiting as a delegate to trigger the processing for the new commits when done waiting.
    current_active_tasks = _duplicate_push_triggers.setdefault(api_url, 0)
-    max_active_tasks = 2 if get_settings().github_app.push_trigger_pending_tasks_backlog else 1
+    max_active_tasks = (
+        2 if get_settings().github_app.push_trigger_pending_tasks_backlog else 1
+    )
    if current_active_tasks < max_active_tasks:
        # first task can enter, and second tasks too if backlog is enabled
        get_logger().info(
@@ -191,12 +254,21 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
|
|||||||
f"Waiting to process push trigger for {api_url=} because the first task is still in progress"
|
f"Waiting to process push trigger for {api_url=} because the first task is still in progress"
|
||||||
)
|
)
|
||||||
await _pending_task_duplicate_push_conditions[api_url].wait()
|
await _pending_task_duplicate_push_conditions[api_url].wait()
|
||||||
get_logger().info(f"Finished waiting to process push trigger for {api_url=} - continue with flow")
|
get_logger().info(
|
||||||
|
f"Finished waiting to process push trigger for {api_url=} - continue with flow"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
|
if (
|
||||||
get_logger().info(f"Performing incremental review for {api_url=} because of {event=} and {action=}")
|
get_identity_provider().verify_eligibility("github", sender_id, api_url)
|
||||||
await _perform_auto_commands_github("push_commands", agent, body, api_url, log_context)
|
is not Eligibility.NOT_ELIGIBLE
|
||||||
|
):
|
||||||
|
get_logger().info(
|
||||||
|
f"Performing incremental review for {api_url=} because of {event=} and {action=}"
|
||||||
|
)
|
||||||
|
await _perform_auto_commands_github(
|
||||||
|
"push_commands", agent, body, api_url, log_context
|
||||||
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# release the waiting task block
|
# release the waiting task block
|
||||||
@@ -213,7 +285,12 @@ def handle_closed_pr(body, event, action, log_context):
    api_url = pull_request.get("url", "")
    pr_statistics = get_git_provider()(pr_url=api_url).calc_pr_statistics(pull_request)
    log_context["api_url"] = api_url
-    get_logger().info("PR-Agent statistics for closed PR", analytics=True, pr_statistics=pr_statistics, **log_context)
+    get_logger().info(
+        "PR-Agent statistics for closed PR",
+        analytics=True,
+        pr_statistics=pr_statistics,
+        **log_context,
+    )


def get_log_context(body, event, action, build_number):
@@ -228,9 +305,18 @@ def get_log_context(body, event, action, build_number):
        git_org = body.get("organization", {}).get("login", "")
        installation_id = body.get("installation", {}).get("id", "")
        app_name = get_settings().get("CONFIG.APP_NAME", "Unknown")
-        log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app",
-                       "request_id": uuid.uuid4().hex, "build_number": build_number, "app_name": app_name,
-                       "repo": repo, "git_org": git_org, "installation_id": installation_id}
+        log_context = {
+            "action": action,
+            "event": event,
+            "sender": sender,
+            "server_type": "github_app",
+            "request_id": uuid.uuid4().hex,
+            "build_number": build_number,
+            "app_name": app_name,
+            "repo": repo,
+            "git_org": git_org,
+            "installation_id": installation_id,
+        }
    except Exception as e:
        get_logger().error("Failed to get log context", e)
        log_context = {}
@@ -240,7 +326,10 @@ def get_log_context(body, event, action, build_number):
def is_bot_user(sender, sender_type):
    try:
        # logic to ignore PRs opened by bot
-        if get_settings().get("GITHUB_APP.IGNORE_BOT_PR", False) and sender_type == "Bot":
+        if (
+            get_settings().get("GITHUB_APP.IGNORE_BOT_PR", False)
+            and sender_type == "Bot"
+        ):
            if 'pr-agent' not in sender:
                get_logger().info(f"Ignoring PR from '{sender=}' because it is a bot")
                return True
@@ -262,7 +351,9 @@ def should_process_pr_logic(body) -> bool:
        ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
        if ignore_pr_users and sender:
            if sender in ignore_pr_users:
-                get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting")
+                get_logger().info(
+                    f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' setting"
+                )
                return False

        # logic to ignore PRs with specific titles
@@ -270,8 +361,12 @@ def should_process_pr_logic(body) -> bool:
        ignore_pr_title_re = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
        if not isinstance(ignore_pr_title_re, list):
            ignore_pr_title_re = [ignore_pr_title_re]
-        if ignore_pr_title_re and any(re.search(regex, title) for regex in ignore_pr_title_re):
-            get_logger().info(f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting")
+        if ignore_pr_title_re and any(
+            re.search(regex, title) for regex in ignore_pr_title_re
+        ):
+            get_logger().info(
+                f"Ignoring PR with title '{title}' due to config.ignore_pr_title setting"
+            )
            return False

        # logic to ignore PRs with specific labels or source branches or target branches.
@@ -280,20 +375,32 @@ def should_process_pr_logic(body) -> bool:
|
|||||||
labels = [label['name'] for label in pr_labels]
|
labels = [label['name'] for label in pr_labels]
|
||||||
if any(label in ignore_pr_labels for label in labels):
|
if any(label in ignore_pr_labels for label in labels):
|
||||||
labels_str = ", ".join(labels)
|
labels_str = ", ".join(labels)
|
||||||
get_logger().info(f"Ignoring PR with labels '{labels_str}' due to config.ignore_pr_labels settings")
|
get_logger().info(
|
||||||
|
f"Ignoring PR with labels '{labels_str}' due to config.ignore_pr_labels settings"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# logic to ignore PRs with specific source or target branches
|
# logic to ignore PRs with specific source or target branches
|
||||||
ignore_pr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [])
|
ignore_pr_source_branches = get_settings().get(
|
||||||
ignore_pr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", [])
|
"CONFIG.IGNORE_PR_SOURCE_BRANCHES", []
|
||||||
|
)
|
||||||
|
ignore_pr_target_branches = get_settings().get(
|
||||||
|
"CONFIG.IGNORE_PR_TARGET_BRANCHES", []
|
||||||
|
)
|
||||||
if pull_request and (ignore_pr_source_branches or ignore_pr_target_branches):
|
if pull_request and (ignore_pr_source_branches or ignore_pr_target_branches):
|
||||||
if any(re.search(regex, source_branch) for regex in ignore_pr_source_branches):
|
if any(
|
||||||
|
re.search(regex, source_branch) for regex in ignore_pr_source_branches
|
||||||
|
):
|
||||||
get_logger().info(
|
get_logger().info(
|
||||||
f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings")
|
f"Ignoring PR with source branch '{source_branch}' due to config.ignore_pr_source_branches settings"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
if any(re.search(regex, target_branch) for regex in ignore_pr_target_branches):
|
if any(
|
||||||
|
re.search(regex, target_branch) for regex in ignore_pr_target_branches
|
||||||
|
):
|
||||||
get_logger().info(
|
get_logger().info(
|
||||||
f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings")
|
f"Ignoring PR with target branch '{target_branch}' due to config.ignore_pr_target_branches settings"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Failed 'should_process_pr_logic': {e}")
|
get_logger().error(f"Failed 'should_process_pr_logic': {e}")
|
||||||
@@ -308,11 +415,15 @@ async def handle_request(body: Dict[str, Any], event: str):
        body: The request body.
        event: The GitHub event type (e.g. "pull_request", "issue_comment", etc.).
    """
-    action = body.get("action")  # "created", "opened", "reopened", "ready_for_review", "review_requested", "synchronize"
+    action = body.get(
+        "action"
+    )  # "created", "opened", "reopened", "ready_for_review", "review_requested", "synchronize"
    if not action:
        return {}
    agent = PRAgent()
-    log_context, sender, sender_id, sender_type = get_log_context(body, event, action, build_number)
+    log_context, sender, sender_id, sender_type = get_log_context(
+        body, event, action, build_number
+    )

    # logic to ignore PRs opened by bot, PRs with specific titles, labels, source branches, or target branches
    if is_bot_user(sender, sender_type) and 'check_run' not in body:
@@ -327,21 +438,29 @@ async def handle_request(body: Dict[str, Any], event: str):
|
|||||||
# handle comments on PRs
|
# handle comments on PRs
|
||||||
elif action == 'created':
|
elif action == 'created':
|
||||||
get_logger().debug(f'Request body', artifact=body, event=event)
|
get_logger().debug(f'Request body', artifact=body, event=event)
|
||||||
await handle_comments_on_pr(body, event, sender, sender_id, action, log_context, agent)
|
await handle_comments_on_pr(
|
||||||
|
body, event, sender, sender_id, action, log_context, agent
|
||||||
|
)
|
||||||
# handle new PRs
|
# handle new PRs
|
||||||
elif event == 'pull_request' and action != 'synchronize' and action != 'closed':
|
elif event == 'pull_request' and action != 'synchronize' and action != 'closed':
|
||||||
get_logger().debug(f'Request body', artifact=body, event=event)
|
get_logger().debug(f'Request body', artifact=body, event=event)
|
||||||
await handle_new_pr_opened(body, event, sender, sender_id, action, log_context, agent)
|
await handle_new_pr_opened(
|
||||||
|
body, event, sender, sender_id, action, log_context, agent
|
||||||
|
)
|
||||||
elif event == "issue_comment" and 'edited' in action:
|
elif event == "issue_comment" and 'edited' in action:
|
||||||
pass # handle_checkbox_clicked
|
pass # handle_checkbox_clicked
|
||||||
# handle pull_request event with synchronize action - "push trigger" for new commits
|
# handle pull_request event with synchronize action - "push trigger" for new commits
|
||||||
elif event == 'pull_request' and action == 'synchronize':
|
elif event == 'pull_request' and action == 'synchronize':
|
||||||
await handle_push_trigger_for_new_commits(body, event, sender,sender_id, action, log_context, agent)
|
await handle_push_trigger_for_new_commits(
|
||||||
|
body, event, sender, sender_id, action, log_context, agent
|
||||||
|
)
|
||||||
elif event == 'pull_request' and action == 'closed':
|
elif event == 'pull_request' and action == 'closed':
|
||||||
if get_settings().get("CONFIG.ANALYTICS_FOLDER", ""):
|
if get_settings().get("CONFIG.ANALYTICS_FOLDER", ""):
|
||||||
handle_closed_pr(body, event, action, log_context)
|
handle_closed_pr(body, event, action, log_context)
|
||||||
else:
|
else:
|
||||||
get_logger().info(f"event {event=} action {action=} does not require any handling")
|
get_logger().info(
|
||||||
|
f"event {event=} action {action=} does not require any handling"
|
||||||
|
)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@@ -362,7 +481,9 @@ def handle_line_comments(body: Dict, comment_body: [str, Any]) -> str:
    return comment_body


-def _check_pull_request_event(action: str, body: dict, log_context: dict) -> Tuple[Dict[str, Any], str]:
+def _check_pull_request_event(
+    action: str, body: dict, log_context: dict
+) -> Tuple[Dict[str, Any], str]:
    invalid_result = {}, ""
    pull_request = body.get("pull_request")
    if not pull_request:
@@ -373,19 +494,28 @@ def _check_pull_request_event(action: str, body: dict, log_context: dict) -> Tup
|
|||||||
log_context["api_url"] = api_url
|
log_context["api_url"] = api_url
|
||||||
if pull_request.get("draft", True) or pull_request.get("state") != "open":
|
if pull_request.get("draft", True) or pull_request.get("state") != "open":
|
||||||
return invalid_result
|
return invalid_result
|
||||||
if action in ("review_requested", "synchronize") and pull_request.get("created_at") == pull_request.get("updated_at"):
|
if action in ("review_requested", "synchronize") and pull_request.get(
|
||||||
|
"created_at"
|
||||||
|
) == pull_request.get("updated_at"):
|
||||||
# avoid double reviews when opening a PR for the first time
|
# avoid double reviews when opening a PR for the first time
|
||||||
return invalid_result
|
return invalid_result
|
||||||
return pull_request, api_url
|
return pull_request, api_url
|
||||||
|
|
||||||
|
|
||||||
async def _perform_auto_commands_github(commands_conf: str, agent: PRAgent, body: dict, api_url: str,
|
async def _perform_auto_commands_github(
|
||||||
log_context: dict):
|
commands_conf: str, agent: PRAgent, body: dict, api_url: str, log_context: dict
|
||||||
|
):
|
||||||
apply_repo_settings(api_url)
|
apply_repo_settings(api_url)
|
||||||
if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback: # auto commands for PR, and auto feedback is disabled
|
if (
|
||||||
get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}")
|
commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
|
||||||
|
): # auto commands for PR, and auto feedback is disabled
|
||||||
|
get_logger().info(
|
||||||
|
f"Auto feedback is disabled, skipping auto commands for PR {api_url=}"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
if not should_process_pr_logic(body): # Here we already updated the configuration with the repo settings
|
if not should_process_pr_logic(
|
||||||
|
body
|
||||||
|
): # Here we already updated the configuration with the repo settings
|
||||||
return {}
|
return {}
|
||||||
commands = get_settings().get(f"github_app.{commands_conf}")
|
commands = get_settings().get(f"github_app.{commands_conf}")
|
||||||
if not commands:
|
if not commands:
|
||||||
@@ -398,7 +528,9 @@ async def _perform_auto_commands_github(commands_conf: str, agent: PRAgent, body
        args = split_command[1:]
        other_args = update_settings_from_args(args)
        new_command = ' '.join([command] + other_args)
-        get_logger().info(f"{commands_conf}. Performing auto command '{new_command}', for {api_url=}")
+        get_logger().info(
+            f"{commands_conf}. Performing auto command '{new_command}', for {api_url=}"
+        )
        await agent.handle_request(api_url, new_command)


@@ -18,11 +18,13 @@ NOTIFICATION_URL = "https://api.github.com/notifications"

async def mark_notification_as_read(headers, notification, session):
    async with session.patch(
        f"https://api.github.com/notifications/threads/{notification['id']}",
-        headers=headers) as mark_read_response:
+        headers=headers,
+    ) as mark_read_response:
        if mark_read_response.status != 205:
            get_logger().error(
-                f"Failed to mark notification as read. Status code: {mark_read_response.status}")
+                f"Failed to mark notification as read. Status code: {mark_read_response.status}"
+            )


def now() -> str:
@@ -36,17 +38,21 @@ def now() -> str:
|
|||||||
now_utc = now_utc.replace("+00:00", "Z")
|
now_utc = now_utc.replace("+00:00", "Z")
|
||||||
return now_utc
|
return now_utc
|
||||||
|
|
||||||
|
|
||||||
async def async_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
|
async def async_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
|
||||||
agent = PRAgent()
|
agent = PRAgent()
|
||||||
success = await agent.handle_request(
|
success = await agent.handle_request(
|
||||||
pr_url,
|
pr_url,
|
||||||
rest_of_comment,
|
rest_of_comment,
|
||||||
notify=lambda: git_provider.add_eyes_reaction(comment_id)
|
notify=lambda: git_provider.add_eyes_reaction(comment_id),
|
||||||
)
|
)
|
||||||
return success
|
return success
|
||||||
|
|
||||||
|
|
||||||
def run_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
|
def run_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
|
||||||
return asyncio.run(async_handle_request(pr_url, rest_of_comment, comment_id, git_provider))
|
return asyncio.run(
|
||||||
|
async_handle_request(pr_url, rest_of_comment, comment_id, git_provider)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def process_comment_sync(pr_url, rest_of_comment, comment_id):
|
def process_comment_sync(pr_url, rest_of_comment, comment_id):
|
||||||
@@ -55,7 +61,10 @@ def process_comment_sync(pr_url, rest_of_comment, comment_id):
        git_provider = get_git_provider()(pr_url=pr_url)
        success = run_handle_request(pr_url, rest_of_comment, comment_id, git_provider)
    except Exception as e:
-        get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()})
+        get_logger().error(
+            f"Error processing comment: {e}",
+            artifact={"traceback": traceback.format_exc()},
+        )


async def process_comment(pr_url, rest_of_comment, comment_id):
@@ -66,22 +75,31 @@ async def process_comment(pr_url, rest_of_comment, comment_id):
|
|||||||
success = await agent.handle_request(
|
success = await agent.handle_request(
|
||||||
pr_url,
|
pr_url,
|
||||||
rest_of_comment,
|
rest_of_comment,
|
||||||
notify=lambda: git_provider.add_eyes_reaction(comment_id)
|
notify=lambda: git_provider.add_eyes_reaction(comment_id),
|
||||||
)
|
)
|
||||||
get_logger().info(f"Finished processing comment for PR: {pr_url}")
|
get_logger().info(f"Finished processing comment for PR: {pr_url}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()})
|
get_logger().error(
|
||||||
|
f"Error processing comment: {e}",
|
||||||
|
artifact={"traceback": traceback.format_exc()},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def is_valid_notification(notification, headers, handled_ids, session, user_id):
|
async def is_valid_notification(notification, headers, handled_ids, session, user_id):
|
||||||
try:
|
try:
|
||||||
if 'reason' in notification and notification['reason'] == 'mention':
|
if 'reason' in notification and notification['reason'] == 'mention':
|
||||||
if 'subject' in notification and notification['subject']['type'] == 'PullRequest':
|
if (
|
||||||
|
'subject' in notification
|
||||||
|
and notification['subject']['type'] == 'PullRequest'
|
||||||
|
):
|
||||||
pr_url = notification['subject']['url']
|
pr_url = notification['subject']['url']
|
||||||
latest_comment = notification['subject']['latest_comment_url']
|
latest_comment = notification['subject']['latest_comment_url']
|
||||||
if not latest_comment or not isinstance(latest_comment, str):
|
if not latest_comment or not isinstance(latest_comment, str):
|
||||||
get_logger().debug(f"no latest_comment")
|
get_logger().debug(f"no latest_comment")
|
||||||
return False, handled_ids
|
return False, handled_ids
|
||||||
async with session.get(latest_comment, headers=headers) as comment_response:
|
async with session.get(
|
||||||
|
latest_comment, headers=headers
|
||||||
|
) as comment_response:
|
||||||
check_prev_comments = False
|
check_prev_comments = False
|
||||||
user_tag = "@" + user_id
|
user_tag = "@" + user_id
|
||||||
if comment_response.status == 200:
|
if comment_response.status == 200:
|
||||||
@@ -94,7 +112,9 @@ async def is_valid_notification(notification, headers, handled_ids, session, use
                            handled_ids.add(comment['id'])
                        if 'user' in comment and 'login' in comment['user']:
                            if comment['user']['login'] == user_id:
-                                get_logger().debug(f"comment['user']['login'] == user_id")
+                                get_logger().debug(
+                                    f"comment['user']['login'] == user_id"
+                                )
                                check_prev_comments = True
                        comment_body = comment.get('body', '')
                        if not comment_body:
@@ -105,15 +125,28 @@ async def is_valid_notification(notification, headers, handled_ids, session, use
|
|||||||
get_logger().debug(f"user_tag not in comment_body")
|
get_logger().debug(f"user_tag not in comment_body")
|
||||||
check_prev_comments = True
|
check_prev_comments = True
|
||||||
else:
|
else:
|
||||||
get_logger().info(f"Polling, pr_url: {pr_url}",
|
get_logger().info(
|
||||||
artifact={"comment": comment_body})
|
f"Polling, pr_url: {pr_url}",
|
||||||
|
artifact={"comment": comment_body},
|
||||||
|
)
|
||||||
|
|
||||||
if not check_prev_comments:
|
if not check_prev_comments:
|
||||||
return True, handled_ids, comment, comment_body, pr_url, user_tag
|
return (
|
||||||
else: # we could not find the user tag in the latest comment. Check previous comments
|
True,
|
||||||
|
handled_ids,
|
||||||
|
comment,
|
||||||
|
comment_body,
|
||||||
|
pr_url,
|
||||||
|
user_tag,
|
||||||
|
)
|
||||||
|
else: # we could not find the user tag in the latest comment. Check previous comments
|
||||||
# get all comments in the PR
|
# get all comments in the PR
|
||||||
requests_url = f"{pr_url}/comments".replace("pulls", "issues")
|
requests_url = f"{pr_url}/comments".replace(
|
||||||
comments_response = requests.get(requests_url, headers=headers)
|
"pulls", "issues"
|
||||||
|
)
|
||||||
|
comments_response = requests.get(
|
||||||
|
requests_url, headers=headers
|
||||||
|
)
|
||||||
comments = comments_response.json()[::-1]
|
comments = comments_response.json()[::-1]
|
||||||
max_comment_to_scan = 4
|
max_comment_to_scan = 4
|
||||||
for comment in comments[:max_comment_to_scan]:
|
for comment in comments[:max_comment_to_scan]:
|
||||||
@@ -124,23 +157,37 @@ async def is_valid_notification(notification, headers, handled_ids, session, use
|
|||||||
if not comment_body:
|
if not comment_body:
|
||||||
continue
|
continue
|
||||||
if user_tag in comment_body:
|
if user_tag in comment_body:
|
||||||
get_logger().info("found user tag in previous comments")
|
get_logger().info(
|
||||||
get_logger().info(f"Polling, pr_url: {pr_url}",
|
"found user tag in previous comments"
|
||||||
artifact={"comment": comment_body})
|
)
|
||||||
return True, handled_ids, comment, comment_body, pr_url, user_tag
|
get_logger().info(
|
||||||
|
f"Polling, pr_url: {pr_url}",
|
||||||
|
artifact={"comment": comment_body},
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
True,
|
||||||
|
handled_ids,
|
||||||
|
comment,
|
||||||
|
comment_body,
|
||||||
|
pr_url,
|
||||||
|
user_tag,
|
||||||
|
)
|
||||||
|
|
||||||
get_logger().warning(f"Failed to fetch comments for PR: {pr_url}",
|
get_logger().warning(
|
||||||
artifact={"comments": comments})
|
f"Failed to fetch comments for PR: {pr_url}",
|
||||||
|
artifact={"comments": comments},
|
||||||
|
)
|
||||||
return False, handled_ids
|
return False, handled_ids
|
||||||
|
|
||||||
return False, handled_ids
|
return False, handled_ids
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().exception(f"Error processing polling notification",
|
get_logger().exception(
|
||||||
artifact={"notification": notification, "error": e})
|
f"Error processing polling notification",
|
||||||
|
artifact={"notification": notification, "error": e},
|
||||||
|
)
|
||||||
return False, handled_ids
|
return False, handled_ids
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def polling_loop():
|
async def polling_loop():
|
||||||
"""
|
"""
|
||||||
Polls for notifications and handles them accordingly.
|
Polls for notifications and handles them accordingly.
|
||||||
@@ -171,17 +218,17 @@ async def polling_loop():
                await asyncio.sleep(5)
                headers = {
                    "Accept": "application/vnd.github.v3+json",
-                    "Authorization": f"Bearer {token}"
-                }
-                params = {
-                    "participating": "true"
+                    "Authorization": f"Bearer {token}",
                }
+                params = {"participating": "true"}
                if since[0]:
                    params["since"] = since[0]
                if last_modified[0]:
                    headers["If-Modified-Since"] = last_modified[0]

-                async with session.get(NOTIFICATION_URL, headers=headers, params=params) as response:
+                async with session.get(
+                    NOTIFICATION_URL, headers=headers, params=params
+                ) as response:
                    if response.status == 200:
                        if 'Last-Modified' in response.headers:
                            last_modified[0] = response.headers['Last-Modified']
@@ -189,39 +236,67 @@ async def polling_loop():
|
|||||||
notifications = await response.json()
|
notifications = await response.json()
|
||||||
if not notifications:
|
if not notifications:
|
||||||
continue
|
continue
|
||||||
get_logger().info(f"Received {len(notifications)} notifications")
|
get_logger().info(
|
||||||
|
f"Received {len(notifications)} notifications"
|
||||||
|
)
|
||||||
task_queue = deque()
|
task_queue = deque()
|
||||||
for notification in notifications:
|
for notification in notifications:
|
||||||
if not notification:
|
if not notification:
|
||||||
continue
|
continue
|
||||||
# mark notification as read
|
# mark notification as read
|
||||||
await mark_notification_as_read(headers, notification, session)
|
await mark_notification_as_read(
|
||||||
|
headers, notification, session
|
||||||
|
)
|
||||||
|
|
||||||
handled_ids.add(notification['id'])
|
handled_ids.add(notification['id'])
|
||||||
output = await is_valid_notification(notification, headers, handled_ids, session, user_id)
|
output = await is_valid_notification(
|
||||||
|
notification, headers, handled_ids, session, user_id
|
||||||
|
)
|
||||||
if output[0]:
|
if output[0]:
|
||||||
_, handled_ids, comment, comment_body, pr_url, user_tag = output
|
(
|
||||||
rest_of_comment = comment_body.split(user_tag)[1].strip()
|
_,
|
||||||
|
handled_ids,
|
||||||
|
comment,
|
||||||
|
comment_body,
|
||||||
|
pr_url,
|
||||||
|
user_tag,
|
||||||
|
) = output
|
||||||
|
rest_of_comment = comment_body.split(user_tag)[
|
||||||
|
1
|
||||||
|
].strip()
|
||||||
comment_id = comment['id']
|
comment_id = comment['id']
|
||||||
|
|
||||||
# Add to the task queue
|
# Add to the task queue
|
||||||
get_logger().info(
|
get_logger().info(
|
||||||
f"Adding comment processing to task queue for PR, {pr_url}, comment_body: {comment_body}")
|
f"Adding comment processing to task queue for PR, {pr_url}, comment_body: {comment_body}"
|
||||||
task_queue.append((process_comment_sync, (pr_url, rest_of_comment, comment_id)))
|
)
|
||||||
get_logger().info(f"Queued comment processing for PR: {pr_url}")
|
task_queue.append(
|
||||||
|
(
|
||||||
|
process_comment_sync,
|
||||||
|
(pr_url, rest_of_comment, comment_id),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
get_logger().info(
|
||||||
|
f"Queued comment processing for PR: {pr_url}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
get_logger().debug(f"Skipping comment processing for PR")
|
get_logger().debug(
|
||||||
|
f"Skipping comment processing for PR"
|
||||||
|
)
|
||||||
|
|
||||||
max_allowed_parallel_tasks = 10
|
max_allowed_parallel_tasks = 10
|
||||||
if task_queue:
|
if task_queue:
|
||||||
processes = []
|
processes = []
|
||||||
for i, (func, args) in enumerate(task_queue): # Create parallel tasks
|
for i, (func, args) in enumerate(
|
||||||
|
task_queue
|
||||||
|
): # Create parallel tasks
|
||||||
p = multiprocessing.Process(target=func, args=args)
|
p = multiprocessing.Process(target=func, args=args)
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
p.start()
|
p.start()
|
||||||
if i > max_allowed_parallel_tasks:
|
if i > max_allowed_parallel_tasks:
|
||||||
get_logger().error(
|
get_logger().error(
|
||||||
f"Dropping {len(task_queue) - max_allowed_parallel_tasks} tasks from polling session")
|
f"Dropping {len(task_queue) - max_allowed_parallel_tasks} tasks from polling session"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
task_queue.clear()
|
task_queue.clear()
|
||||||
|
|
||||||
@@ -230,11 +305,15 @@ async def polling_loop():
                        # p.join()

                    elif response.status != 304:
-                        print(f"Failed to fetch notifications. Status code: {response.status}")
+                        print(
+                            f"Failed to fetch notifications. Status code: {response.status}"
+                        )

            except Exception as e:
-                get_logger().error(f"Polling exception during processing of a notification: {e}",
-                                   artifact={"traceback": traceback.format_exc()})
+                get_logger().error(
+                    f"Polling exception during processing of a notification: {e}",
+                    artifact={"traceback": traceback.format_exc()},
+                )


if __name__ == '__main__':
@@ -22,20 +22,21 @@ from utils.pr_agent.secret_providers import get_secret_provider
|
|||||||
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
|
setup_logger(fmt=LoggingFormat.JSON, level="DEBUG")
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
secret_provider = get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
|
secret_provider = (
|
||||||
|
get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id):
|
async def get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id):
|
||||||
try:
|
try:
|
||||||
import requests
|
import requests
|
||||||
headers = {
|
|
||||||
'Private-Token': f'{gitlab_token}'
|
headers = {'Private-Token': f'{gitlab_token}'}
|
||||||
}
|
|
||||||
# API endpoint to find MRs containing the commit
|
# API endpoint to find MRs containing the commit
|
||||||
gitlab_url = get_settings().get("GITLAB.URL", 'https://gitlab.com')
|
gitlab_url = get_settings().get("GITLAB.URL", 'https://gitlab.com')
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f'{gitlab_url}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/merge_requests',
|
f'{gitlab_url}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/merge_requests',
|
||||||
headers=headers
|
headers=headers,
|
||||||
)
|
)
|
||||||
merge_requests = response.json()
|
merge_requests = response.json()
|
||||||
if merge_requests and response.status_code == 200:
|
if merge_requests and response.status_code == 200:
|
||||||
@@ -48,6 +49,7 @@ async def get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id):
|
|||||||
get_logger().error(f"Failed to get MR url from commit sha: {e}")
|
get_logger().error(f"Failed to get MR url from commit sha: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def handle_request(api_url: str, body: str, log_context: dict, sender_id: str):
|
async def handle_request(api_url: str, body: str, log_context: dict, sender_id: str):
|
||||||
log_context["action"] = body
|
log_context["action"] = body
|
||||||
log_context["event"] = "pull_request" if body == "/review" else "comment"
|
log_context["event"] = "pull_request" if body == "/review" else "comment"
|
||||||
@@ -58,13 +60,19 @@ async def handle_request(api_url: str, body: str, log_context: dict, sender_id:
        await PRAgent().handle_request(api_url, body)


-async def _perform_commands_gitlab(commands_conf: str, agent: PRAgent, api_url: str,
-                                   log_context: dict, data: dict):
+async def _perform_commands_gitlab(
+    commands_conf: str, agent: PRAgent, api_url: str, log_context: dict, data: dict
+):
    apply_repo_settings(api_url)
-    if commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback:  # auto commands for PR, and auto feedback is disabled
-        get_logger().info(f"Auto feedback is disabled, skipping auto commands for PR {api_url=}", **log_context)
+    if (
+        commands_conf == "pr_commands" and get_settings().config.disable_auto_feedback
+    ):  # auto commands for PR, and auto feedback is disabled
+        get_logger().info(
+            f"Auto feedback is disabled, skipping auto commands for PR {api_url=}",
+            **log_context,
+        )
        return
    if not should_process_pr_logic(data):  # Here we already updated the configurations
        return
    commands = get_settings().get(f"gitlab.{commands_conf}", {})
    get_settings().set("config.is_auto_command", True)
@@ -106,40 +114,58 @@ def should_process_pr_logic(data) -> bool:
|
|||||||
ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
|
ignore_pr_users = get_settings().get("CONFIG.IGNORE_PR_AUTHORS", [])
|
||||||
if ignore_pr_users and sender:
|
if ignore_pr_users and sender:
|
||||||
if sender in ignore_pr_users:
|
if sender in ignore_pr_users:
|
||||||
get_logger().info(f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' settings")
|
get_logger().info(
|
||||||
|
f"Ignoring PR from user '{sender}' due to 'config.ignore_pr_authors' settings"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# logic to ignore MRs for titles, labels and source, target branches.
|
# logic to ignore MRs for titles, labels and source, target branches.
|
||||||
ignore_mr_title = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
|
ignore_mr_title = get_settings().get("CONFIG.IGNORE_PR_TITLE", [])
|
||||||
ignore_mr_labels = get_settings().get("CONFIG.IGNORE_PR_LABELS", [])
|
ignore_mr_labels = get_settings().get("CONFIG.IGNORE_PR_LABELS", [])
|
||||||
ignore_mr_source_branches = get_settings().get("CONFIG.IGNORE_PR_SOURCE_BRANCHES", [])
|
ignore_mr_source_branches = get_settings().get(
|
||||||
ignore_mr_target_branches = get_settings().get("CONFIG.IGNORE_PR_TARGET_BRANCHES", [])
|
"CONFIG.IGNORE_PR_SOURCE_BRANCHES", []
|
||||||
|
)
|
||||||
|
ignore_mr_target_branches = get_settings().get(
|
||||||
|
"CONFIG.IGNORE_PR_TARGET_BRANCHES", []
|
||||||
|
)
|
||||||
|
|
||||||
#
|
#
|
||||||
if ignore_mr_source_branches:
|
if ignore_mr_source_branches:
|
||||||
source_branch = data['object_attributes'].get('source_branch')
|
source_branch = data['object_attributes'].get('source_branch')
|
||||||
if any(re.search(regex, source_branch) for regex in ignore_mr_source_branches):
|
if any(
|
||||||
|
re.search(regex, source_branch) for regex in ignore_mr_source_branches
|
||||||
|
):
|
||||||
get_logger().info(
|
get_logger().info(
|
||||||
f"Ignoring MR with source branch '{source_branch}' due to gitlab.ignore_mr_source_branches settings")
|
f"Ignoring MR with source branch '{source_branch}' due to gitlab.ignore_mr_source_branches settings"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if ignore_mr_target_branches:
|
if ignore_mr_target_branches:
|
||||||
target_branch = data['object_attributes'].get('target_branch')
|
target_branch = data['object_attributes'].get('target_branch')
|
||||||
if any(re.search(regex, target_branch) for regex in ignore_mr_target_branches):
|
if any(
|
||||||
|
re.search(regex, target_branch) for regex in ignore_mr_target_branches
|
||||||
|
):
|
||||||
get_logger().info(
|
get_logger().info(
|
||||||
f"Ignoring MR with target branch '{target_branch}' due to gitlab.ignore_mr_target_branches settings")
|
f"Ignoring MR with target branch '{target_branch}' due to gitlab.ignore_mr_target_branches settings"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if ignore_mr_labels:
|
if ignore_mr_labels:
|
||||||
labels = [label['title'] for label in data['object_attributes'].get('labels', [])]
|
labels = [
|
||||||
|
label['title'] for label in data['object_attributes'].get('labels', [])
|
||||||
|
]
|
||||||
if any(label in ignore_mr_labels for label in labels):
|
if any(label in ignore_mr_labels for label in labels):
|
||||||
labels_str = ", ".join(labels)
|
labels_str = ", ".join(labels)
|
||||||
get_logger().info(f"Ignoring MR with labels '{labels_str}' due to gitlab.ignore_mr_labels settings")
|
get_logger().info(
|
||||||
|
f"Ignoring MR with labels '{labels_str}' due to gitlab.ignore_mr_labels settings"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if ignore_mr_title:
|
if ignore_mr_title:
|
||||||
if any(re.search(regex, title) for regex in ignore_mr_title):
|
if any(re.search(regex, title) for regex in ignore_mr_title):
|
||||||
get_logger().info(f"Ignoring MR with title '{title}' due to gitlab.ignore_mr_title settings")
|
get_logger().info(
|
||||||
|
f"Ignoring MR with title '{title}' due to gitlab.ignore_mr_title settings"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Failed 'should_process_pr_logic': {e}")
|
get_logger().error(f"Failed 'should_process_pr_logic': {e}")
|
||||||
@@ -159,29 +185,47 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
|
|||||||
request_token = request.headers.get("X-Gitlab-Token")
|
request_token = request.headers.get("X-Gitlab-Token")
|
||||||
secret = secret_provider.get_secret(request_token)
|
secret = secret_provider.get_secret(request_token)
|
||||||
if not secret:
|
if not secret:
|
||||||
get_logger().warning(f"Empty secret retrieved, request_token: {request_token}")
|
get_logger().warning(
|
||||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED,
|
f"Empty secret retrieved, request_token: {request_token}"
|
||||||
content=jsonable_encoder({"message": "unauthorized"}))
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
content=jsonable_encoder({"message": "unauthorized"}),
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
secret_dict = json.loads(secret)
|
secret_dict = json.loads(secret)
|
||||||
gitlab_token = secret_dict["gitlab_token"]
|
gitlab_token = secret_dict["gitlab_token"]
|
||||||
log_context["token_id"] = secret_dict.get("token_name", secret_dict.get("id", "unknown"))
|
log_context["token_id"] = secret_dict.get(
|
||||||
|
"token_name", secret_dict.get("id", "unknown")
|
||||||
|
)
|
||||||
context["settings"].gitlab.personal_access_token = gitlab_token
|
context["settings"].gitlab.personal_access_token = gitlab_token
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Failed to validate secret {request_token}: {e}")
|
get_logger().error(f"Failed to validate secret {request_token}: {e}")
|
||||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
content=jsonable_encoder({"message": "unauthorized"}),
|
||||||
|
)
|
||||||
elif get_settings().get("GITLAB.SHARED_SECRET"):
|
elif get_settings().get("GITLAB.SHARED_SECRET"):
|
||||||
secret = get_settings().get("GITLAB.SHARED_SECRET")
|
secret = get_settings().get("GITLAB.SHARED_SECRET")
|
||||||
if not request.headers.get("X-Gitlab-Token") == secret:
|
if not request.headers.get("X-Gitlab-Token") == secret:
|
||||||
get_logger().error("Failed to validate secret")
|
get_logger().error("Failed to validate secret")
|
||||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
content=jsonable_encoder({"message": "unauthorized"}),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
get_logger().error("Failed to validate secret")
|
get_logger().error("Failed to validate secret")
|
||||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
content=jsonable_encoder({"message": "unauthorized"}),
|
||||||
|
)
|
||||||
gitlab_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
|
gitlab_token = get_settings().get("GITLAB.PERSONAL_ACCESS_TOKEN", None)
|
||||||
if not gitlab_token:
|
if not gitlab_token:
|
||||||
get_logger().error("No gitlab token found")
|
get_logger().error("No gitlab token found")
|
||||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=jsonable_encoder({"message": "unauthorized"}))
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
content=jsonable_encoder({"message": "unauthorized"}),
|
||||||
|
)
|
||||||
|
|
||||||
get_logger().info("GitLab data", artifact=data)
|
get_logger().info("GitLab data", artifact=data)
|
||||||
sender = data.get("user", {}).get("username", "unknown")
|
sender = data.get("user", {}).get("username", "unknown")
|
||||||
@@ -189,31 +233,49 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
|
|||||||
|
|
||||||
# ignore bot users
|
# ignore bot users
|
||||||
if is_bot_user(data):
|
if is_bot_user(data):
|
||||||
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
|
return JSONResponse(
|
||||||
if data.get('event_type') != 'note': # not a comment
|
status_code=status.HTTP_200_OK,
|
||||||
|
content=jsonable_encoder({"message": "success"}),
|
||||||
|
)
|
||||||
|
if data.get('event_type') != 'note': # not a comment
|
||||||
# ignore MRs based on title, labels, source and target branches
|
# ignore MRs based on title, labels, source and target branches
|
||||||
if not should_process_pr_logic(data):
|
if not should_process_pr_logic(data):
|
||||||
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content=jsonable_encoder({"message": "success"}),
|
||||||
|
)
|
||||||
|
|
||||||
log_context["sender"] = sender
|
log_context["sender"] = sender
|
||||||
if data.get('object_kind') == 'merge_request' and data['object_attributes'].get('action') in ['open', 'reopen']:
|
if data.get('object_kind') == 'merge_request' and data['object_attributes'].get(
|
||||||
|
'action'
|
||||||
|
) in ['open', 'reopen']:
|
||||||
title = data['object_attributes'].get('title')
|
title = data['object_attributes'].get('title')
|
||||||
url = data['object_attributes'].get('url')
|
url = data['object_attributes'].get('url')
|
||||||
draft = data['object_attributes'].get('draft')
|
draft = data['object_attributes'].get('draft')
|
||||||
get_logger().info(f"New merge request: {url}")
|
get_logger().info(f"New merge request: {url}")
|
||||||
if draft:
|
if draft:
|
||||||
get_logger().info(f"Skipping draft MR: {url}")
|
get_logger().info(f"Skipping draft MR: {url}")
|
||||||
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content=jsonable_encoder({"message": "success"}),
|
||||||
|
)
|
||||||
|
|
||||||
await _perform_commands_gitlab("pr_commands", PRAgent(), url, log_context, data)
|
await _perform_commands_gitlab(
|
||||||
elif data.get('object_kind') == 'note' and data.get('event_type') == 'note': # comment on MR
|
"pr_commands", PRAgent(), url, log_context, data
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
data.get('object_kind') == 'note' and data.get('event_type') == 'note'
|
||||||
|
): # comment on MR
|
||||||
if 'merge_request' in data:
|
if 'merge_request' in data:
|
||||||
mr = data['merge_request']
|
mr = data['merge_request']
|
||||||
url = mr.get('url')
|
url = mr.get('url')
|
||||||
|
|
||||||
get_logger().info(f"A comment has been added to a merge request: {url}")
|
get_logger().info(f"A comment has been added to a merge request: {url}")
|
||||||
body = data.get('object_attributes', {}).get('note')
|
body = data.get('object_attributes', {}).get('note')
|
||||||
if data.get('object_attributes', {}).get('type') == 'DiffNote' and '/ask' in body: # /ask_line
|
if (
|
||||||
|
data.get('object_attributes', {}).get('type') == 'DiffNote'
|
||||||
|
and '/ask' in body
|
||||||
|
): # /ask_line
|
||||||
body = handle_ask_line(body, data)
|
body = handle_ask_line(body, data)
|
||||||
|
|
||||||
await handle_request(url, body, log_context, sender_id)
|
await handle_request(url, body, log_context, sender_id)
|
||||||
@ -221,30 +283,44 @@ async def gitlab_webhook(background_tasks: BackgroundTasks, request: Request):
|
|||||||
try:
|
try:
|
||||||
project_id = data['project_id']
|
project_id = data['project_id']
|
||||||
commit_sha = data['checkout_sha']
|
commit_sha = data['checkout_sha']
|
||||||
url = await get_mr_url_from_commit_sha(commit_sha, gitlab_token, project_id)
|
url = await get_mr_url_from_commit_sha(
|
||||||
|
commit_sha, gitlab_token, project_id
|
||||||
|
)
|
||||||
if not url:
|
if not url:
|
||||||
get_logger().info(f"No MR found for commit: {commit_sha}")
|
get_logger().info(f"No MR found for commit: {commit_sha}")
|
||||||
return JSONResponse(status_code=status.HTTP_200_OK,
|
return JSONResponse(
|
||||||
content=jsonable_encoder({"message": "success"}))
|
status_code=status.HTTP_200_OK,
|
||||||
|
content=jsonable_encoder({"message": "success"}),
|
||||||
|
)
|
||||||
|
|
||||||
# we need first to apply_repo_settings
|
# we need first to apply_repo_settings
|
||||||
apply_repo_settings(url)
|
apply_repo_settings(url)
|
||||||
commands_on_push = get_settings().get(f"gitlab.push_commands", {})
|
commands_on_push = get_settings().get(f"gitlab.push_commands", {})
|
||||||
handle_push_trigger = get_settings().get(f"gitlab.handle_push_trigger", False)
|
handle_push_trigger = get_settings().get(
|
||||||
|
f"gitlab.handle_push_trigger", False
|
||||||
|
)
|
||||||
if not commands_on_push or not handle_push_trigger:
|
if not commands_on_push or not handle_push_trigger:
|
||||||
get_logger().info("Push event, but no push commands found or push trigger is disabled")
|
get_logger().info(
|
||||||
return JSONResponse(status_code=status.HTTP_200_OK,
|
"Push event, but no push commands found or push trigger is disabled"
|
||||||
content=jsonable_encoder({"message": "success"}))
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content=jsonable_encoder({"message": "success"}),
|
||||||
|
)
|
||||||
|
|
||||||
get_logger().debug(f'A push event has been received: {url}')
|
get_logger().debug(f'A push event has been received: {url}')
|
||||||
await _perform_commands_gitlab("push_commands", PRAgent(), url, log_context, data)
|
await _perform_commands_gitlab(
|
||||||
|
"push_commands", PRAgent(), url, log_context, data
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Failed to handle push event: {e}")
|
get_logger().error(f"Failed to handle push event: {e}")
|
||||||
|
|
||||||
background_tasks.add_task(inner, request_json)
|
background_tasks.add_task(inner, request_json)
|
||||||
end_time = datetime.now()
|
end_time = datetime.now()
|
||||||
get_logger().info(f"Processing time: {end_time - start_time}", request=request_json)
|
get_logger().info(f"Processing time: {end_time - start_time}", request=request_json)
|
||||||
return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"}))
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK, content=jsonable_encoder({"message": "success"})
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def handle_ask_line(body, data):
|
def handle_ask_line(body, data):
|
||||||
@ -271,6 +347,7 @@ def handle_ask_line(body, data):
|
|||||||
async def root():
|
async def root():
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
gitlab_url = get_settings().get("GITLAB.URL", None)
|
gitlab_url = get_settings().get("GITLAB.URL", None)
|
||||||
if not gitlab_url:
|
if not gitlab_url:
|
||||||
raise ValueError("GITLAB.URL is not set")
|
raise ValueError("GITLAB.URL is not set")
|
||||||
|
|||||||
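
The hunks above repeatedly build the same FastAPI acknowledgement response and defer the real work to a background task. A minimal, self-contained sketch of that pattern follows; the route path, function names, and message text here are placeholders for illustration, not taken from this repository:

from fastapi import BackgroundTasks, FastAPI, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse

app = FastAPI()

async def process_event(data: dict) -> None:
    # stand-in for the long-running webhook handling
    ...

@app.post("/example-webhook")  # hypothetical route, for illustration only
async def example_webhook(background_tasks: BackgroundTasks, request: Request):
    data = await request.json()
    # heavy work is deferred so the webhook can acknowledge quickly
    background_tasks.add_task(process_event, data)
    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content=jsonable_encoder({"message": "success"}),
    )
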
@@ -1,18 +1,19 @@
 class HelpMessage:
     @staticmethod
     def get_general_commands_text():
-        commands_text = "> - **/review**: Request a review of your Pull Request. \n" \
-                        "> - **/describe**: Update the PR title and description based on the contents of the PR. \n" \
-                        "> - **/improve [--extended]**: Suggest code improvements. Extended mode provides a higher quality feedback. \n" \
-                        "> - **/ask \\<QUESTION\\>**: Ask a question about the PR. \n" \
-                        "> - **/update_changelog**: Update the changelog based on the PR's contents. \n" \
-                        "> - **/add_docs** 💎: Generate docstring for new components introduced in the PR. \n" \
-                        "> - **/generate_labels** 💎: Generate labels for the PR based on the PR's contents. \n" \
-                        "> - **/analyze** 💎: Automatically analyzes the PR, and presents changes walkthrough for each component. \n\n" \
-                        ">See the [tools guide](https://pr-agent-docs.codium.ai/tools/) for more details.\n" \
-                        ">To list the possible configuration parameters, add a **/config** comment. \n"
-        return commands_text
+        commands_text = (
+            "> - **/review**: Request a review of your Pull Request. \n"
+            "> - **/describe**: Update the PR title and description based on the contents of the PR. \n"
+            "> - **/improve [--extended]**: Suggest code improvements. Extended mode provides a higher quality feedback. \n"
+            "> - **/ask \\<QUESTION\\>**: Ask a question about the PR. \n"
+            "> - **/update_changelog**: Update the changelog based on the PR's contents. \n"
+            "> - **/add_docs** 💎: Generate docstring for new components introduced in the PR. \n"
+            "> - **/generate_labels** 💎: Generate labels for the PR based on the PR's contents. \n"
+            "> - **/analyze** 💎: Automatically analyzes the PR, and presents changes walkthrough for each component. \n\n"
+            ">See the [tools guide](https://pr-agent-docs.codium.ai/tools/) for more details.\n"
+            ">To list the possible configuration parameters, add a **/config** comment. \n"
+        )
+        return commands_text

     @staticmethod
     def get_general_bot_help_text():
@@ -21,10 +22,12 @@ class HelpMessage:

     @staticmethod
     def get_review_usage_guide():
-        output ="**Overview:**\n"
-        output +=("The `review` tool scans the PR code changes, and generates a PR review which includes several types of feedbacks, such as possible PR issues, security threats and relevant test in the PR. More feedbacks can be [added](https://pr-agent-docs.codium.ai/tools/review/#general-configurations) by configuring the tool.\n\n"
-                  "The tool can be triggered [automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) every time a new PR is opened, or can be invoked manually by commenting on any PR.\n")
-        output +="""\
+        output = "**Overview:**\n"
+        output += (
+            "The `review` tool scans the PR code changes, and generates a PR review which includes several types of feedbacks, such as possible PR issues, security threats and relevant test in the PR. More feedbacks can be [added](https://pr-agent-docs.codium.ai/tools/review/#general-configurations) by configuring the tool.\n\n"
+            "The tool can be triggered [automatically](https://pr-agent-docs.codium.ai/usage-guide/automations_and_usage/#github-app-automatic-tools-when-a-new-pr-is-opened) every time a new PR is opened, or can be invoked manually by commenting on any PR.\n"
+        )
+        output += """\
 - When commenting, to edit [configurations](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L23) related to the review tool (`pr_reviewer` section), use the following template:
 ```
 /review --pr_reviewer.some_config1=... --pr_reviewer.some_config2=...
@@ -41,8 +44,6 @@ some_config2=...

         return output

-
-
     @staticmethod
     def get_describe_usage_guide():
         output = "**Overview:**\n"
@@ -137,7 +138,6 @@ Use triple quotes to write multi-line instructions. Use bullet points to make th
 '''
         output += "\n\n</details></td></tr>\n\n"

-
         # general
         output += "\n\n<tr><td><details> <summary><strong> More PR-Agent commands</strong></summary><hr> \n\n"
         output += HelpMessage.get_general_bot_help_text()
@@ -175,7 +175,6 @@ You can ask questions about the entire PR, about specific code lines, or about a

         return output

-
     @staticmethod
     def get_improve_usage_guide():
         output = "**Overview:**\n"
@@ -18,8 +18,12 @@ def verify_signature(payload_body, secret_token, signature_header):
         signature_header: header received from GitHub (x-hub-signature-256)
     """
     if not signature_header:
-        raise HTTPException(status_code=403, detail="x-hub-signature-256 header is missing!")
-    hash_object = hmac.new(secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256)
+        raise HTTPException(
+            status_code=403, detail="x-hub-signature-256 header is missing!"
+        )
+    hash_object = hmac.new(
+        secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256
+    )
     expected_signature = "sha256=" + hash_object.hexdigest()
     if not hmac.compare_digest(expected_signature, signature_header):
         raise HTTPException(status_code=403, detail="Request signatures didn't match!")
@@ -27,6 +31,7 @@ def verify_signature(payload_body, secret_token, signature_header):

 class RateLimitExceeded(Exception):
     """Raised when the git provider API rate limit has been exceeded."""
+
     pass

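
For reference, a minimal usage sketch of the HMAC comparison performed in verify_signature above; the secret and payload values are invented for illustration:

import hashlib
import hmac

payload_body = b'{"action": "opened"}'   # example payload bytes
secret_token = "my-webhook-secret"       # example secret, not from this repo

digest = hmac.new(secret_token.encode("utf-8"), msg=payload_body, digestmod=hashlib.sha256)
expected_signature = "sha256=" + digest.hexdigest()

# GitHub sends this value in the x-hub-signature-256 header
received_header = expected_signature
print(hmac.compare_digest(expected_signature, received_header))  # True
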
@@ -66,7 +71,11 @@ class DefaultDictWithTimeout(defaultdict):
         request_time = self.__time()
         if request_time - self.__last_refresh > self.__refresh_interval:
             return
-        to_delete = [key for key, key_time in self.__key_times.items() if request_time - key_time > self.__ttl]
+        to_delete = [
+            key
+            for key, key_time in self.__key_times.items()
+            if request_time - key_time > self.__ttl
+        ]
         for key in to_delete:
             del self[key]
         self.__last_refresh = request_time
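
The refresh logic above drops expired keys with a list comprehension. A standalone sketch of the same idea, with illustrative names and TTL value:

import time

ttl = 60  # seconds, illustrative
key_times = {"a": time.time() - 120, "b": time.time()}
data = {"a": 1, "b": 2}

now = time.time()
to_delete = [key for key, key_time in key_times.items() if now - key_time > ttl]
for key in to_delete:
    del data[key]
    del key_times[key]

print(data)  # {'b': 2}; "a" expired and was purged
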
@@ -17,9 +17,13 @@ from utils.pr_agent.log import get_logger


 class PRAddDocs:
-    def __init__(self, pr_url: str, cli_mode=False, args: list = None,
-                 ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
+    def __init__(
+        self,
+        pr_url: str,
+        cli_mode=False,
+        args: list = None,
+        ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
+    ):
         self.git_provider = get_git_provider()(pr_url)
         self.main_language = get_main_pr_language(
             self.git_provider.get_languages(), self.git_provider.get_files()
@@ -39,13 +43,16 @@ class PRAddDocs:
             "diff": "", # empty diff for initial calculation
             "extra_instructions": get_settings().pr_add_docs.extra_instructions,
             "commit_messages_str": self.git_provider.get_commit_messages(),
-            'docs_for_language': get_docs_for_language(self.main_language,
-                                                       get_settings().pr_add_docs.docs_style),
+            'docs_for_language': get_docs_for_language(
+                self.main_language, get_settings().pr_add_docs.docs_style
+            ),
         }
-        self.token_handler = TokenHandler(self.git_provider.pr,
-                                          self.vars,
-                                          get_settings().pr_add_docs_prompt.system,
-                                          get_settings().pr_add_docs_prompt.user)
+        self.token_handler = TokenHandler(
+            self.git_provider.pr,
+            self.vars,
+            get_settings().pr_add_docs_prompt.system,
+            get_settings().pr_add_docs_prompt.user,
+        )

     async def run(self):
         try:
@@ -66,16 +73,20 @@ class PRAddDocs:
             get_logger().info('Pushing inline code documentation...')
             self.push_inline_docs(data)
         except Exception as e:
-            get_logger().error(f"Failed to generate code documentation for PR, error: {e}")
+            get_logger().error(
+                f"Failed to generate code documentation for PR, error: {e}"
+            )

     async def _prepare_prediction(self, model: str):
         get_logger().info('Getting PR diff...')

-        self.patches_diff = get_pr_diff(self.git_provider,
-                                        self.token_handler,
-                                        model,
-                                        add_line_numbers_to_hunks=True,
-                                        disable_extra_lines=False)
+        self.patches_diff = get_pr_diff(
+            self.git_provider,
+            self.token_handler,
+            model,
+            add_line_numbers_to_hunks=True,
+            disable_extra_lines=False,
+        )

         get_logger().info('Getting AI prediction...')
         self.prediction = await self._get_prediction(model)
@@ -84,13 +95,21 @@ class PRAddDocs:
         variables = copy.deepcopy(self.vars)
         variables["diff"] = self.patches_diff # update diff
         environment = Environment(undefined=StrictUndefined)
-        system_prompt = environment.from_string(get_settings().pr_add_docs_prompt.system).render(variables)
-        user_prompt = environment.from_string(get_settings().pr_add_docs_prompt.user).render(variables)
+        system_prompt = environment.from_string(
+            get_settings().pr_add_docs_prompt.system
+        ).render(variables)
+        user_prompt = environment.from_string(
+            get_settings().pr_add_docs_prompt.user
+        ).render(variables)
         if get_settings().config.verbosity_level >= 2:
             get_logger().info(f"\nSystem prompt:\n{system_prompt}")
             get_logger().info(f"\nUser prompt:\n{user_prompt}")
         response, finish_reason = await self.ai_handler.chat_completion(
-            model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
+            model=model,
+            temperature=get_settings().config.temperature,
+            system=system_prompt,
+            user=user_prompt,
+        )

         return response

@@ -105,7 +124,9 @@ class PRAddDocs:
         docs = []

         if not data['Code Documentation']:
-            return self.git_provider.publish_comment('No code documentation found to improve this PR.')
+            return self.git_provider.publish_comment(
+                'No code documentation found to improve this PR.'
+            )

         for d in data['Code Documentation']:
             try:
@@ -116,32 +137,59 @@ class PRAddDocs:
                 documentation = d['documentation']
                 doc_placement = d['doc placement'].strip()
                 if documentation:
-                    new_code_snippet = self.dedent_code(relevant_file, relevant_line, documentation, doc_placement,
-                                                        add_original_line=True)
+                    new_code_snippet = self.dedent_code(
+                        relevant_file,
+                        relevant_line,
+                        documentation,
+                        doc_placement,
+                        add_original_line=True,
+                    )

-                    body = f"**Suggestion:** Proposed documentation\n```suggestion\n" + new_code_snippet + "\n```"
-                    docs.append({'body': body, 'relevant_file': relevant_file,
-                                 'relevant_lines_start': relevant_line,
-                                 'relevant_lines_end': relevant_line})
+                    body = (
+                        f"**Suggestion:** Proposed documentation\n```suggestion\n"
+                        + new_code_snippet
+                        + "\n```"
+                    )
+                    docs.append(
+                        {
+                            'body': body,
+                            'relevant_file': relevant_file,
+                            'relevant_lines_start': relevant_line,
+                            'relevant_lines_end': relevant_line,
+                        }
+                    )
             except Exception:
                 if get_settings().config.verbosity_level >= 2:
                     get_logger().info(f"Could not parse code docs: {d}")

         is_successful = self.git_provider.publish_code_suggestions(docs)
         if not is_successful:
-            get_logger().info("Failed to publish code docs, trying to publish each docs separately")
+            get_logger().info(
+                "Failed to publish code docs, trying to publish each docs separately"
+            )
             for doc_suggestion in docs:
                 self.git_provider.publish_code_suggestions([doc_suggestion])

-    def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet, doc_placement='after',
-                    add_original_line=False):
+    def dedent_code(
+        self,
+        relevant_file,
+        relevant_lines_start,
+        new_code_snippet,
+        doc_placement='after',
+        add_original_line=False,
+    ):
         try: # dedent code snippet
-            self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files \
-                else self.git_provider.get_diff_files()
+            self.diff_files = (
+                self.git_provider.diff_files
+                if self.git_provider.diff_files
+                else self.git_provider.get_diff_files()
+            )
             original_initial_line = None
             for file in self.diff_files:
                 if file.filename.strip() == relevant_file:
-                    original_initial_line = file.head_file.splitlines()[relevant_lines_start - 1]
+                    original_initial_line = file.head_file.splitlines()[
+                        relevant_lines_start - 1
+                    ]
                     break
             if original_initial_line:
                 if doc_placement == 'after':
@@ -150,18 +198,28 @@ class PRAddDocs:
                     line = original_initial_line
                     suggested_initial_line = new_code_snippet.splitlines()[0]
                     original_initial_spaces = len(line) - len(line.lstrip())
-                    suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip())
+                    suggested_initial_spaces = len(suggested_initial_line) - len(
+                        suggested_initial_line.lstrip()
+                    )
                     delta_spaces = original_initial_spaces - suggested_initial_spaces
                     if delta_spaces > 0:
-                        new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
+                        new_code_snippet = textwrap.indent(
+                            new_code_snippet, delta_spaces * " "
+                        ).rstrip('\n')
             if add_original_line:
                 if doc_placement == 'after':
-                    new_code_snippet = original_initial_line + "\n" + new_code_snippet
+                    new_code_snippet = (
+                        original_initial_line + "\n" + new_code_snippet
+                    )
                 else:
-                    new_code_snippet = new_code_snippet.rstrip() + "\n" + original_initial_line
+                    new_code_snippet = (
+                        new_code_snippet.rstrip() + "\n" + original_initial_line
+                    )
         except Exception as e:
             if get_settings().config.verbosity_level >= 2:
-                get_logger().info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")
+                get_logger().info(
+                    f"Could not dedent code snippet for file {relevant_file}, error: {e}"
+                )

         return new_code_snippet

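
The prompt rendering in the _get_prediction hunk above relies on Jinja2 with StrictUndefined. A small self-contained sketch of that technique; the template text and variables are invented for illustration:

from jinja2 import Environment, StrictUndefined

environment = Environment(undefined=StrictUndefined)
template = "Review the diff for {{ title }}:\n{{ diff }}"  # illustrative template
variables = {"title": "example PR", "diff": "+ added line"}

# StrictUndefined makes a missing variable raise instead of rendering empty text
prompt = environment.from_string(template).render(variables)
print(prompt)
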
File diff suppressed because it is too large
@@ -9,6 +9,7 @@ class PRConfig:
     """
     The PRConfig class is responsible for listing all configuration options available for the user.
     """
+
     def __init__(self, pr_url: str, args=None, ai_handler=None):
         """
         Initialize the PRConfig object with the necessary attributes and objects to comment on a pull request.
@@ -34,20 +35,43 @@ class PRConfig:
         conf_settings = Dynaconf(settings_files=[conf_file])
         configuration_headers = [header.lower() for header in conf_settings.keys()]
         relevant_configs = {
-            header: configs for header, configs in get_settings().to_dict().items()
-            if (header.lower().startswith("pr_") or header.lower().startswith("config")) and header.lower() in configuration_headers
+            header: configs
+            for header, configs in get_settings().to_dict().items()
+            if (header.lower().startswith("pr_") or header.lower().startswith("config"))
+            and header.lower() in configuration_headers
         }

-        skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys", "app_id", "redirect",
-                     'trial_prefix_message', 'no_eligible_message', 'identity_provider', 'ALLOWED_REPOS',
-                     'APP_NAME', 'PERSONAL_ACCESS_TOKEN', 'shared_secret', 'key', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'user_token',
-                     'private_key', 'private_key_id', 'client_id', 'client_secret', 'token', 'bearer_token']
+        skip_keys = [
+            'ai_disclaimer',
+            'ai_disclaimer_title',
+            'ANALYTICS_FOLDER',
+            'secret_provider',
+            "skip_keys",
+            "app_id",
+            "redirect",
+            'trial_prefix_message',
+            'no_eligible_message',
+            'identity_provider',
+            'ALLOWED_REPOS',
+            'APP_NAME',
+            'PERSONAL_ACCESS_TOKEN',
+            'shared_secret',
+            'key',
+            'AWS_ACCESS_KEY_ID',
+            'AWS_SECRET_ACCESS_KEY',
+            'user_token',
+            'private_key',
+            'private_key_id',
+            'client_id',
+            'client_secret',
+            'token',
+            'bearer_token',
+        ]
         extra_skip_keys = get_settings().config.get('config.skip_keys', [])
         if extra_skip_keys:
             skip_keys.extend(extra_skip_keys)
         skip_keys_lower = [key.lower() for key in skip_keys]


         markdown_text = "<details> <summary><strong>🛠️ PR-Agent Configurations:</strong></summary> \n\n"
         markdown_text += f"\n\n```yaml\n\n"
         for header, configs in relevant_configs.items():
@@ -61,5 +85,7 @@ class PRConfig:
                 markdown_text += " "
         markdown_text += "\n```"
         markdown_text += "\n</details>\n"
-        get_logger().info(f"Possible Configurations outputted to PR comment", artifact=markdown_text)
+        get_logger().info(
+            f"Possible Configurations outputted to PR comment", artifact=markdown_text
+        )
         return markdown_text
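
The PRConfig hunk above filters the settings mapping by header prefix and masks secret-like keys. The same filtering pattern in plain Python, with invented sample data standing in for the real settings object:

settings = {
    "pr_description": {"publish_labels": True, "user_token": "***"},
    "config": {"verbosity_level": 1},
    "github": {"private_key": "***"},
}
configuration_headers = ["pr_description", "config"]

relevant_configs = {
    header: configs
    for header, configs in settings.items()
    if (header.lower().startswith("pr_") or header.lower().startswith("config"))
    and header.lower() in configuration_headers
}

skip_keys_lower = ["user_token", "private_key"]
for header, configs in relevant_configs.items():
    relevant_configs[header] = {
        key: value for key, value in configs.items() if key.lower() not in skip_keys_lower
    }

print(relevant_configs)  # {'pr_description': {'publish_labels': True}, 'config': {'verbosity_level': 1}}
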
@@ -10,27 +10,38 @@ from jinja2 import Environment, StrictUndefined

 from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
 from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
-from utils.pr_agent.algo.pr_processing import (OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD,
-                                               get_pr_diff,
-                                               get_pr_diff_multiple_patchs,
-                                               retry_with_fallback_models)
+from utils.pr_agent.algo.pr_processing import (
+    OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD,
+    get_pr_diff,
+    get_pr_diff_multiple_patchs,
+    retry_with_fallback_models,
+)
 from utils.pr_agent.algo.token_handler import TokenHandler
-from utils.pr_agent.algo.utils import (ModelType, PRDescriptionHeader, clip_tokens,
-                                       get_max_tokens, get_user_labels, load_yaml,
-                                       set_custom_labels,
-                                       show_relevant_configurations)
+from utils.pr_agent.algo.utils import (
+    ModelType,
+    PRDescriptionHeader,
+    clip_tokens,
+    get_max_tokens,
+    get_user_labels,
+    load_yaml,
+    set_custom_labels,
+    show_relevant_configurations,
+)
 from utils.pr_agent.config_loader import get_settings
-from utils.pr_agent.git_providers import (GithubProvider, get_git_provider_with_context)
+from utils.pr_agent.git_providers import GithubProvider, get_git_provider_with_context
 from utils.pr_agent.git_providers.git_provider import get_main_pr_language
 from utils.pr_agent.log import get_logger
 from utils.pr_agent.servers.help import HelpMessage
-from utils.pr_agent.tools.ticket_pr_compliance_check import (
-    extract_and_cache_pr_tickets)
+from utils.pr_agent.tools.ticket_pr_compliance_check import extract_and_cache_pr_tickets


 class PRDescription:
-    def __init__(self, pr_url: str, args: list = None,
-                 ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
+    def __init__(
+        self,
+        pr_url: str,
+        args: list = None,
+        ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
+    ):
         """
         Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
         using an AI model.
@@ -44,11 +55,22 @@ class PRDescription:
             self.git_provider.get_languages(), self.git_provider.get_files()
         )
         self.pr_id = self.git_provider.get_pr_id()
-        self.keys_fix = ["filename:", "language:", "changes_summary:", "changes_title:", "description:", "title:"]
+        self.keys_fix = [
+            "filename:",
+            "language:",
+            "changes_summary:",
+            "changes_title:",
+            "description:",
+            "title:",
+        ]

-        if get_settings().pr_description.enable_semantic_files_types and not self.git_provider.is_supported(
-                "gfm_markdown"):
-            get_logger().debug(f"Disabling semantic files types for {self.pr_id}, gfm_markdown not supported.")
+        if (
+            get_settings().pr_description.enable_semantic_files_types
+            and not self.git_provider.is_supported("gfm_markdown")
+        ):
+            get_logger().debug(
+                f"Disabling semantic files types for {self.pr_id}, gfm_markdown not supported."
+            )
             get_settings().pr_description.enable_semantic_files_types = False

         # Initialize the AI handler
@@ -56,7 +78,9 @@ class PRDescription:
         self.ai_handler.main_pr_language = self.main_pr_language

         # Initialize the variables dictionary
-        self.COLLAPSIBLE_FILE_LIST_THRESHOLD = get_settings().pr_description.get("collapsible_file_list_threshold", 8)
+        self.COLLAPSIBLE_FILE_LIST_THRESHOLD = get_settings().pr_description.get(
+            "collapsible_file_list_threshold", 8
+        )
         self.vars = {
             "title": self.git_provider.pr.title,
             "branch": self.git_provider.get_pr_branch(),
@@ -69,8 +93,11 @@ class PRDescription:
             "custom_labels_class": "", # will be filled if necessary in 'set_custom_labels' function
             "enable_semantic_files_types": get_settings().pr_description.enable_semantic_files_types,
             "related_tickets": "",
-            "include_file_summary_changes": len(self.git_provider.get_diff_files()) <= self.COLLAPSIBLE_FILE_LIST_THRESHOLD,
-            'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False),
+            "include_file_summary_changes": len(self.git_provider.get_diff_files())
+            <= self.COLLAPSIBLE_FILE_LIST_THRESHOLD,
+            'duplicate_prompt_examples': get_settings().config.get(
+                'duplicate_prompt_examples', False
+            ),
         }

         self.user_description = self.git_provider.get_user_description()
@@ -91,10 +118,14 @@ class PRDescription:
     async def run(self):
         try:
             get_logger().info(f"Generating a PR description for pr_id: {self.pr_id}")
-            relevant_configs = {'pr_description': dict(get_settings().pr_description),
-                                'config': dict(get_settings().config)}
+            relevant_configs = {
+                'pr_description': dict(get_settings().pr_description),
+                'config': dict(get_settings().config),
+            }
             get_logger().debug("Relevant configs", artifacts=relevant_configs)
-            if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False):
+            if get_settings().config.publish_output and not get_settings().config.get(
+                'is_auto_command', False
+            ):
                 self.git_provider.publish_comment("准备 PR 描述中...", is_temporary=True)

             # ticket extraction if exists
@@ -119,40 +150,73 @@ class PRDescription:
                 get_logger().debug(f"Publishing labels disabled")

             if get_settings().pr_description.use_description_markers:
-                pr_title, pr_body, changes_walkthrough, pr_file_changes = self._prepare_pr_answer_with_markers()
+                (
+                    pr_title,
+                    pr_body,
+                    changes_walkthrough,
+                    pr_file_changes,
+                ) = self._prepare_pr_answer_with_markers()
             else:
-                pr_title, pr_body, changes_walkthrough, pr_file_changes = self._prepare_pr_answer()
-            if not self.git_provider.is_supported(
-                    "publish_file_comments") or not get_settings().pr_description.inline_file_summary:
+                (
+                    pr_title,
+                    pr_body,
+                    changes_walkthrough,
+                    pr_file_changes,
+                ) = self._prepare_pr_answer()
+            if (
+                not self.git_provider.is_supported("publish_file_comments")
+                or not get_settings().pr_description.inline_file_summary
+            ):
                 pr_body += "\n\n" + changes_walkthrough
-            get_logger().debug("PR output", artifact={"title": pr_title, "body": pr_body})
+            get_logger().debug(
+                "PR output", artifact={"title": pr_title, "body": pr_body}
+            )

             # Add help text if gfm_markdown is supported
-            if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_description.enable_help_text:
+            if (
+                self.git_provider.is_supported("gfm_markdown")
+                and get_settings().pr_description.enable_help_text
+            ):
                 pr_body += "<hr>\n\n<details> <summary><strong>✨ 工具使用指南:</strong></summary><hr> \n\n"
                 pr_body += HelpMessage.get_describe_usage_guide()
                 pr_body += "\n</details>\n"
-            elif get_settings().pr_description.enable_help_comment and self.git_provider.is_supported("gfm_markdown"):
+            elif (
+                get_settings().pr_description.enable_help_comment
+                and self.git_provider.is_supported("gfm_markdown")
+            ):
                 if isinstance(self.git_provider, GithubProvider):
-                    pr_body += ('\n\n___\n\n> <details> <summary> 需要帮助?</summary><li>Type <code>/help 如何 ...</code> '
-                                '关于PR-Agent使用的任何问题,请在评论区留言.</li><li>查看一下 '
-                                '<a href="https://qodo-merge-docs.qodo.ai/usage-guide/">documentation</a> '
-                                '了解更多.</li></details>')
-                else: # gitlab
-                    pr_body += ("\n\n___\n\n<details><summary>需要帮助?</summary>- Type <code>/help 如何 ...</code> 在评论中 "
-                                "关于PR-Agent使用的任何问题请在此发帖. <br>- 查看一下 "
-                                "<a href='https://qodo-merge-docs.qodo.ai/usage-guide/'>documentation</a> 了解更多.</details>")
+                    pr_body += (
+                        '\n\n___\n\n> <details> <summary> 需要帮助?</summary><li>Type <code>/help 如何 ...</code> '
+                        '关于PR-Agent使用的任何问题,请在评论区留言.</li><li>查看一下 '
+                        '<a href="https://qodo-merge-docs.qodo.ai/usage-guide/">documentation</a> '
+                        '了解更多.</li></details>'
+                    )
+                else: # gitlab
+                    pr_body += (
+                        "\n\n___\n\n<details><summary>需要帮助?</summary>- Type <code>/help 如何 ...</code> 在评论中 "
+                        "关于PR-Agent使用的任何问题请在此发帖. <br>- 查看一下 "
+                        "<a href='https://qodo-merge-docs.qodo.ai/usage-guide/'>documentation</a> 了解更多.</details>"
+                    )
             # elif get_settings().pr_description.enable_help_comment:
             #     pr_body += '\n\n___\n\n> 💡 **PR-Agent usage**: Comment `/help "your question"` on any pull request to receive relevant information'

             # Output the relevant configurations if enabled
-            if get_settings().get('config', {}).get('output_relevant_configurations', False):
-                pr_body += show_relevant_configurations(relevant_section='pr_description')
+            if (
+                get_settings()
+                .get('config', {})
+                .get('output_relevant_configurations', False)
+            ):
+                pr_body += show_relevant_configurations(
+                    relevant_section='pr_description'
+                )

             if get_settings().config.publish_output:

                 # publish labels
-                if get_settings().pr_description.publish_labels and pr_labels and self.git_provider.is_supported("get_labels"):
+                if (
+                    get_settings().pr_description.publish_labels
+                    and pr_labels
+                    and self.git_provider.is_supported("get_labels")
+                ):
                     original_labels = self.git_provider.get_pr_labels(update=True)
                     get_logger().debug(f"original labels", artifact=original_labels)
                     user_labels = get_user_labels(original_labels)
@@ -165,20 +229,29 @@ class PRDescription:

                 # publish description
                 if get_settings().pr_description.publish_description_as_comment:
-                    full_markdown_description = f"## Title\n\n{pr_title}\n\n___\n{pr_body}"
-                    if get_settings().pr_description.publish_description_as_comment_persistent:
-                        self.git_provider.publish_persistent_comment(full_markdown_description,
-                                                                     initial_header="## Title",
-                                                                     update_header=True,
-                                                                     name="describe",
-                                                                     final_update_message=False, )
+                    full_markdown_description = (
+                        f"## Title\n\n{pr_title}\n\n___\n{pr_body}"
+                    )
+                    if (
+                        get_settings().pr_description.publish_description_as_comment_persistent
+                    ):
+                        self.git_provider.publish_persistent_comment(
+                            full_markdown_description,
+                            initial_header="## Title",
+                            update_header=True,
+                            name="describe",
+                            final_update_message=False,
+                        )
                     else:
                         self.git_provider.publish_comment(full_markdown_description)
                 else:
                     self.git_provider.publish_description(pr_title, pr_body)

                 # publish final update message
-                if (get_settings().pr_description.final_update_message and not get_settings().config.get('is_auto_command', False)):
+                if (
+                    get_settings().pr_description.final_update_message
+                    and not get_settings().config.get('is_auto_command', False)
+                ):
                     latest_commit_url = self.git_provider.get_latest_commit_url()
                     if latest_commit_url:
                         pr_url = self.git_provider.get_pr_url()
@@ -186,22 +259,40 @@ class PRDescription:
                         self.git_provider.publish_comment(update_comment)
                 self.git_provider.remove_initial_comment()
             else:
-                get_logger().info('PR description, but not published since publish_output is False.')
+                get_logger().info(
+                    'PR description, but not published since publish_output is False.'
+                )
                 get_settings().data = {"artifact": pr_body}
                 return
         except Exception as e:
-            get_logger().error(f"Error generating PR description {self.pr_id}: {e}",
-                               artifact={"traceback": traceback.format_exc()})
+            get_logger().error(
+                f"Error generating PR description {self.pr_id}: {e}",
+                artifact={"traceback": traceback.format_exc()},
+            )

         return ""

     async def _prepare_prediction(self, model: str) -> None:
-        if get_settings().pr_description.use_description_markers and 'pr_agent:' not in self.user_description:
-            get_logger().info("Markers were enabled, but user description does not contain markers. skipping AI prediction")
+        if (
+            get_settings().pr_description.use_description_markers
+            and 'pr_agent:' not in self.user_description
+        ):
+            get_logger().info(
+                "Markers were enabled, but user description does not contain markers. skipping AI prediction"
+            )
             return None

-        large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings()
-        output = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling, return_remaining_files=True)
+        large_pr_handling = (
+            get_settings().pr_description.enable_large_pr_handling
+            and "pr_description_only_files_prompts" in get_settings()
+        )
+        output = get_pr_diff(
+            self.git_provider,
+            self.token_handler,
+            model,
+            large_pr_handling=large_pr_handling,
+            return_remaining_files=True,
+        )
         if isinstance(output, tuple):
             patches_diff, remaining_files_list = output
         else:
@@ -213,14 +304,18 @@ class PRDescription:
             if patches_diff:
                 # generate the prediction
                 get_logger().debug(f"PR diff", artifact=self.patches_diff)
-                self.prediction = await self._get_prediction(model, patches_diff, prompt="pr_description_prompt")
+                self.prediction = await self._get_prediction(
+                    model, patches_diff, prompt="pr_description_prompt"
+                )

                 # extend the prediction with additional files not shown
                 if get_settings().pr_description.enable_semantic_files_types:
                     self.prediction = await self.extend_uncovered_files(self.prediction)
             else:
-                get_logger().error(f"Error getting PR diff {self.pr_id}",
-                                   artifact={"traceback": traceback.format_exc()})
+                get_logger().error(
+                    f"Error getting PR diff {self.pr_id}",
+                    artifact={"traceback": traceback.format_exc()},
+                )
                 self.prediction = None
         else:
             # get the diff in multiple patches, with the token handler only for the files prompt
@@ -231,9 +326,16 @@ class PRDescription:
                 get_settings().pr_description_only_files_prompts.system,
                 get_settings().pr_description_only_files_prompts.user,
             )
-            (patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict,
-             files_in_patches_list) = get_pr_diff_multiple_patchs(
-                self.git_provider, token_handler_only_files_prompt, model)
+            (
+                patches_compressed_list,
+                total_tokens_list,
+                deleted_files_list,
+                remaining_files_list,
+                file_dict,
+                files_in_patches_list,
+            ) = get_pr_diff_multiple_patchs(
+                self.git_provider, token_handler_only_files_prompt, model
+            )

             # get the files prediction for each patch
             if not get_settings().pr_description.async_ai_calls:
@@ -241,8 +343,9 @@ class PRDescription:
                 for i, patches in enumerate(patches_compressed_list): # sync calls
                     patches_diff = "\n".join(patches)
                     get_logger().debug(f"PR diff number {i + 1} for describe files")
-                    prediction_files = await self._get_prediction(model, patches_diff,
-                                                                  prompt="pr_description_only_files_prompts")
+                    prediction_files = await self._get_prediction(
+                        model, patches_diff, prompt="pr_description_only_files_prompts"
+                    )
                     results.append(prediction_files)
             else: # async calls
                 tasks = []
@@ -251,34 +354,52 @@ class PRDescription:
                     patches_diff = "\n".join(patches)
                     get_logger().debug(f"PR diff number {i + 1} for describe files")
                     task = asyncio.create_task(
-                        self._get_prediction(model, patches_diff, prompt="pr_description_only_files_prompts"))
+                        self._get_prediction(
+                            model,
+                            patches_diff,
+                            prompt="pr_description_only_files_prompts",
+                        )
+                    )
                     tasks.append(task)
                 # Wait for all tasks to complete
                 results = await asyncio.gather(*tasks)
             file_description_str_list = []
             for i, result in enumerate(results):
-                prediction_files = result.strip().removeprefix('```yaml').strip('`').strip()
-                if load_yaml(prediction_files, keys_fix_yaml=self.keys_fix) and prediction_files.startswith('pr_files'):
-                    prediction_files = prediction_files.removeprefix('pr_files:').strip()
+                prediction_files = (
+                    result.strip().removeprefix('```yaml').strip('`').strip()
+                )
+                if load_yaml(
+                    prediction_files, keys_fix_yaml=self.keys_fix
+                ) and prediction_files.startswith('pr_files'):
+                    prediction_files = prediction_files.removeprefix(
+                        'pr_files:'
+                    ).strip()
                    file_description_str_list.append(prediction_files)
                 else:
-                    get_logger().debug(f"failed to generate predictions in iteration {i + 1} for describe files")
+                    get_logger().debug(
+                        f"failed to generate predictions in iteration {i + 1} for describe files"
+                    )

             # generate files_walkthrough string, with proper token handling
             token_handler_only_description_prompt = TokenHandler(
                 self.git_provider.pr,
                 self.vars,
                 get_settings().pr_description_only_description_prompts.system,
-                get_settings().pr_description_only_description_prompts.user)
+                get_settings().pr_description_only_description_prompts.user,
+            )
             files_walkthrough = "\n".join(file_description_str_list)
             files_walkthrough_prompt = copy.deepcopy(files_walkthrough)
             MAX_EXTRA_FILES_TO_PROMPT = 50
             if remaining_files_list:
-                files_walkthrough_prompt += "\n\nNo more token budget. Additional unprocessed files:"
+                files_walkthrough_prompt += (
+                    "\n\nNo more token budget. Additional unprocessed files:"
+                )
                 for i, file in enumerate(remaining_files_list):
                     files_walkthrough_prompt += f"\n- {file}"
                     if i >= MAX_EXTRA_FILES_TO_PROMPT:
-                        get_logger().debug(f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}")
+                        get_logger().debug(
+                            f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}"
+                        )
                         files_walkthrough_prompt += f"\n... and {len(remaining_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more"
                         break
             if deleted_files_list:
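
The async branch above fans out one prediction call per patch with asyncio.create_task and collects the results with asyncio.gather. A minimal sketch of that pattern with a stubbed coroutine in place of the real model call:

import asyncio

async def get_prediction(patch: str) -> str:
    # stand-in for the real AI call
    await asyncio.sleep(0.01)
    return f"prediction for {patch}"

async def main() -> None:
    patches = ["patch-1", "patch-2", "patch-3"]
    tasks = [asyncio.create_task(get_prediction(p)) for p in patches]
    # Wait for all tasks to complete
    results = await asyncio.gather(*tasks)
    print(results)

asyncio.run(main())
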
@ -286,32 +407,57 @@ class PRDescription:
|
|||||||
for i, file in enumerate(deleted_files_list):
|
for i, file in enumerate(deleted_files_list):
|
||||||
files_walkthrough_prompt += f"\n- {file}"
|
files_walkthrough_prompt += f"\n- {file}"
|
||||||
if i >= MAX_EXTRA_FILES_TO_PROMPT:
|
if i >= MAX_EXTRA_FILES_TO_PROMPT:
|
||||||
get_logger().debug(f"Too many deleted files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}")
|
get_logger().debug(
|
||||||
|
f"Too many deleted files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}"
|
||||||
|
)
|
||||||
files_walkthrough_prompt += f"\n... and {len(deleted_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more"
|
files_walkthrough_prompt += f"\n... and {len(deleted_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more"
|
||||||
break
|
break
|
||||||
tokens_files_walkthrough = len(
|
tokens_files_walkthrough = len(
|
||||||
token_handler_only_description_prompt.encoder.encode(files_walkthrough_prompt))
|
token_handler_only_description_prompt.encoder.encode(
|
||||||
total_tokens = token_handler_only_description_prompt.prompt_tokens + tokens_files_walkthrough
|
files_walkthrough_prompt
|
||||||
|
)
|
||||||
|
)
|
||||||
|
total_tokens = (
|
||||||
|
token_handler_only_description_prompt.prompt_tokens
|
||||||
|
+ tokens_files_walkthrough
|
||||||
|
)
|
||||||
max_tokens_model = get_max_tokens(model)
|
max_tokens_model = get_max_tokens(model)
|
||||||
if total_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
|
if total_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
|
||||||
# clip files_walkthrough to git the tokens within the limit
|
# clip files_walkthrough to git the tokens within the limit
|
||||||
files_walkthrough_prompt = clip_tokens(files_walkthrough_prompt,
|
files_walkthrough_prompt = clip_tokens(
|
||||||
max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD - token_handler_only_description_prompt.prompt_tokens,
|
files_walkthrough_prompt,
|
||||||
num_input_tokens=tokens_files_walkthrough)
|
max_tokens_model
|
||||||
|
- OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD
|
||||||
|
- token_handler_only_description_prompt.prompt_tokens,
|
||||||
|
num_input_tokens=tokens_files_walkthrough,
|
||||||
|
)
|
||||||
|
|
||||||
# PR header inference
|
# PR header inference
|
||||||
get_logger().debug(f"PR diff only description", artifact=files_walkthrough_prompt)
|
get_logger().debug(
|
||||||
prediction_headers = await self._get_prediction(model, patches_diff=files_walkthrough_prompt,
|
f"PR diff only description", artifact=files_walkthrough_prompt
|
||||||
prompt="pr_description_only_description_prompts")
|
)
|
||||||
prediction_headers = prediction_headers.strip().removeprefix('```yaml').strip('`').strip()
|
prediction_headers = await self._get_prediction(
|
||||||
|
model,
|
||||||
|
patches_diff=files_walkthrough_prompt,
|
||||||
|
prompt="pr_description_only_description_prompts",
|
||||||
|
)
|
||||||
|
prediction_headers = (
|
||||||
|
prediction_headers.strip().removeprefix('```yaml').strip('`').strip()
|
||||||
|
)
|
||||||
|
|
||||||
# extend the tables with the files not shown
|
# extend the tables with the files not shown
|
||||||
files_walkthrough_extended = await self.extend_uncovered_files(files_walkthrough)
|
files_walkthrough_extended = await self.extend_uncovered_files(
|
||||||
|
files_walkthrough
|
||||||
|
)
|
||||||
|
|
||||||
# final processing
|
# final processing
|
||||||
self.prediction = prediction_headers + "\n" + "pr_files:\n" + files_walkthrough_extended
|
self.prediction = (
|
||||||
|
prediction_headers + "\n" + "pr_files:\n" + files_walkthrough_extended
|
||||||
|
)
|
||||||
if not load_yaml(self.prediction, keys_fix_yaml=self.keys_fix):
|
if not load_yaml(self.prediction, keys_fix_yaml=self.keys_fix):
|
||||||
get_logger().error(f"Error getting valid YAML in large PR handling for describe {self.pr_id}")
|
get_logger().error(
|
||||||
|
f"Error getting valid YAML in large PR handling for describe {self.pr_id}"
|
||||||
|
)
|
||||||
if load_yaml(prediction_headers, keys_fix_yaml=self.keys_fix):
|
if load_yaml(prediction_headers, keys_fix_yaml=self.keys_fix):
|
||||||
get_logger().debug(f"Using only headers for describe {self.pr_id}")
|
get_logger().debug(f"Using only headers for describe {self.pr_id}")
|
||||||
self.prediction = prediction_headers
|
self.prediction = prediction_headers
|
||||||
@@ -321,12 +467,17 @@ class PRDescription:
         prediction = original_prediction

         # get the original prediction filenames
-        original_prediction_loaded = load_yaml(original_prediction, keys_fix_yaml=self.keys_fix)
+        original_prediction_loaded = load_yaml(
+            original_prediction, keys_fix_yaml=self.keys_fix
+        )
         if isinstance(original_prediction_loaded, list):
             original_prediction_dict = {"pr_files": original_prediction_loaded}
         else:
             original_prediction_dict = original_prediction_loaded
-        filenames_predicted = [file['filename'].strip() for file in original_prediction_dict.get('pr_files', [])]
+        filenames_predicted = [
+            file['filename'].strip()
+            for file in original_prediction_dict.get('pr_files', [])
+        ]

         # extend the prediction with additional files not included in the original prediction
         pr_files = self.git_provider.get_diff_files()
@ -349,7 +500,9 @@ class PRDescription:
|
|||||||
additional files
|
additional files
|
||||||
"""
|
"""
|
||||||
prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
|
prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
|
||||||
get_logger().debug(f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_OUTPUT}")
|
get_logger().debug(
|
||||||
|
f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_OUTPUT}"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
extra_file_yaml = f"""\
|
extra_file_yaml = f"""\
|
||||||
@ -364,10 +517,18 @@ class PRDescription:
|
|||||||
|
|
||||||
# merge the two dictionaries
|
# merge the two dictionaries
|
||||||
if counter_extra_files > 0:
|
if counter_extra_files > 0:
|
||||||
get_logger().info(f"Adding {counter_extra_files} unprocessed extra files to table prediction")
|
get_logger().info(
|
||||||
prediction_extra_dict = load_yaml(prediction_extra, keys_fix_yaml=self.keys_fix)
|
f"Adding {counter_extra_files} unprocessed extra files to table prediction"
|
||||||
if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
|
)
|
||||||
original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
|
prediction_extra_dict = load_yaml(
|
||||||
|
prediction_extra, keys_fix_yaml=self.keys_fix
|
||||||
|
)
|
||||||
|
if isinstance(original_prediction_dict, dict) and isinstance(
|
||||||
|
prediction_extra_dict, dict
|
||||||
|
):
|
||||||
|
original_prediction_dict["pr_files"].extend(
|
||||||
|
prediction_extra_dict["pr_files"]
|
||||||
|
)
|
||||||
new_yaml = yaml.dump(original_prediction_dict)
|
new_yaml = yaml.dump(original_prediction_dict)
|
||||||
if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix):
|
if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix):
|
||||||
prediction = new_yaml
|
prediction = new_yaml
|
||||||
@ -379,11 +540,12 @@ class PRDescription:
|
|||||||
get_logger().error(f"Error extending uncovered files {self.pr_id}: {e}")
|
get_logger().error(f"Error extending uncovered files {self.pr_id}: {e}")
|
||||||
return original_prediction
|
return original_prediction
|
||||||
|
|
||||||
|
|
||||||
async def extend_additional_files(self, remaining_files_list) -> str:
|
async def extend_additional_files(self, remaining_files_list) -> str:
|
||||||
prediction = self.prediction
|
prediction = self.prediction
|
||||||
try:
|
try:
|
||||||
original_prediction_dict = load_yaml(self.prediction, keys_fix_yaml=self.keys_fix)
|
original_prediction_dict = load_yaml(
|
||||||
|
self.prediction, keys_fix_yaml=self.keys_fix
|
||||||
|
)
|
||||||
prediction_extra = "pr_files:"
|
prediction_extra = "pr_files:"
|
||||||
for file in remaining_files_list:
|
for file in remaining_files_list:
|
||||||
extra_file_yaml = f"""\
|
extra_file_yaml = f"""\
|
||||||
@ -397,10 +559,16 @@ class PRDescription:
|
|||||||
additional files (token-limit)
|
additional files (token-limit)
|
||||||
"""
|
"""
|
||||||
prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
|
prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
|
||||||
prediction_extra_dict = load_yaml(prediction_extra, keys_fix_yaml=self.keys_fix)
|
prediction_extra_dict = load_yaml(
|
||||||
|
prediction_extra, keys_fix_yaml=self.keys_fix
|
||||||
|
)
|
||||||
# merge the two dictionaries
|
# merge the two dictionaries
|
||||||
if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
|
if isinstance(original_prediction_dict, dict) and isinstance(
|
||||||
original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
|
prediction_extra_dict, dict
|
||||||
|
):
|
||||||
|
original_prediction_dict["pr_files"].extend(
|
||||||
|
prediction_extra_dict["pr_files"]
|
||||||
|
)
|
||||||
new_yaml = yaml.dump(original_prediction_dict)
|
new_yaml = yaml.dump(original_prediction_dict)
|
||||||
if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix):
|
if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix):
|
||||||
prediction = new_yaml
|
prediction = new_yaml
|
||||||
@@ -409,7 +577,9 @@ class PRDescription:
             get_logger().error(f"Error extending additional files {self.pr_id}: {e}")
             return self.prediction

-    async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str:
+    async def _get_prediction(
+        self, model: str, patches_diff: str, prompt="pr_description_prompt"
+    ) -> str:
         variables = copy.deepcopy(self.vars)
         variables["diff"] = patches_diff  # update diff

@@ -417,14 +587,18 @@ class PRDescription:
         set_custom_labels(variables, self.git_provider)
         self.variables = variables

-        system_prompt = environment.from_string(get_settings().get(prompt, {}).get("system", "")).render(self.variables)
-        user_prompt = environment.from_string(get_settings().get(prompt, {}).get("user", "")).render(self.variables)
+        system_prompt = environment.from_string(
+            get_settings().get(prompt, {}).get("system", "")
+        ).render(self.variables)
+        user_prompt = environment.from_string(
+            get_settings().get(prompt, {}).get("user", "")
+        ).render(self.variables)

         response, finish_reason = await self.ai_handler.chat_completion(
             model=model,
             temperature=get_settings().config.temperature,
             system=system_prompt,
-            user=user_prompt
+            user=user_prompt,
         )

         return response
@@ -433,7 +607,10 @@ class PRDescription:
         # Load the AI prediction data into a dictionary
         self.data = load_yaml(self.prediction.strip(), keys_fix_yaml=self.keys_fix)

-        if get_settings().pr_description.add_original_user_description and self.user_description:
+        if (
+            get_settings().pr_description.add_original_user_description
+            and self.user_description
+        ):
             self.data["User Description"] = self.user_description

         # re-order keys
@ -459,7 +636,11 @@ class PRDescription:
|
|||||||
pr_labels = self.data['labels']
|
pr_labels = self.data['labels']
|
||||||
elif type(self.data['labels']) == str:
|
elif type(self.data['labels']) == str:
|
||||||
pr_labels = self.data['labels'].split(',')
|
pr_labels = self.data['labels'].split(',')
|
||||||
elif 'type' in self.data and self.data['type'] and get_settings().pr_description.publish_labels:
|
elif (
|
||||||
|
'type' in self.data
|
||||||
|
and self.data['type']
|
||||||
|
and get_settings().pr_description.publish_labels
|
||||||
|
):
|
||||||
if type(self.data['type']) == list:
|
if type(self.data['type']) == list:
|
||||||
pr_labels = self.data['type']
|
pr_labels = self.data['type']
|
||||||
elif type(self.data['type']) == str:
|
elif type(self.data['type']) == str:
|
||||||
@ -474,7 +655,9 @@ class PRDescription:
|
|||||||
if label_i in d:
|
if label_i in d:
|
||||||
pr_labels[i] = d[label_i]
|
pr_labels[i] = d[label_i]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
|
get_logger().error(
|
||||||
|
f"Error converting labels to original case {self.pr_id}: {e}"
|
||||||
|
)
|
||||||
return pr_labels
|
return pr_labels
|
||||||
|
|
||||||
def _prepare_pr_answer_with_markers(self) -> Tuple[str, str, str, List[dict]]:
|
def _prepare_pr_answer_with_markers(self) -> Tuple[str, str, str, List[dict]]:
|
||||||
@ -482,7 +665,7 @@ class PRDescription:
|
|||||||
|
|
||||||
# Remove the 'PR Title' key from the dictionary
|
# Remove the 'PR Title' key from the dictionary
|
||||||
ai_title = self.data.pop('title', self.vars["title"])
|
ai_title = self.data.pop('title', self.vars["title"])
|
||||||
if (not get_settings().pr_description.generate_ai_title):
|
if not get_settings().pr_description.generate_ai_title:
|
||||||
# Assign the original PR title to the 'title' variable
|
# Assign the original PR title to the 'title' variable
|
||||||
title = self.vars["title"]
|
title = self.vars["title"]
|
||||||
else:
|
else:
|
||||||
@ -514,8 +697,9 @@ class PRDescription:
|
|||||||
pr_file_changes = []
|
pr_file_changes = []
|
||||||
if ai_walkthrough and not re.search(r'<!--\s*pr_agent:walkthrough\s*-->', body):
|
if ai_walkthrough and not re.search(r'<!--\s*pr_agent:walkthrough\s*-->', body):
|
||||||
try:
|
try:
|
||||||
walkthrough_gfm, pr_file_changes = self.process_pr_files_prediction(walkthrough_gfm,
|
walkthrough_gfm, pr_file_changes = self.process_pr_files_prediction(
|
||||||
self.file_label_dict)
|
walkthrough_gfm, self.file_label_dict
|
||||||
|
)
|
||||||
body = body.replace('pr_agent:walkthrough', walkthrough_gfm)
|
body = body.replace('pr_agent:walkthrough', walkthrough_gfm)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Failing to process walkthrough {self.pr_id}: {e}")
|
get_logger().error(f"Failing to process walkthrough {self.pr_id}: {e}")
|
||||||
@ -545,7 +729,7 @@ class PRDescription:
|
|||||||
|
|
||||||
# Remove the 'PR Title' key from the dictionary
|
# Remove the 'PR Title' key from the dictionary
|
||||||
ai_title = self.data.pop('title', self.vars["title"])
|
ai_title = self.data.pop('title', self.vars["title"])
|
||||||
if (not get_settings().pr_description.generate_ai_title):
|
if not get_settings().pr_description.generate_ai_title:
|
||||||
# Assign the original PR title to the 'title' variable
|
# Assign the original PR title to the 'title' variable
|
||||||
title = self.vars["title"]
|
title = self.vars["title"]
|
||||||
else:
|
else:
|
||||||
@ -575,13 +759,20 @@ class PRDescription:
|
|||||||
pr_body += f'- `{filename}`: {description}\n'
|
pr_body += f'- `{filename}`: {description}\n'
|
||||||
if self.git_provider.is_supported("gfm_markdown"):
|
if self.git_provider.is_supported("gfm_markdown"):
|
||||||
pr_body += "</details>\n"
|
pr_body += "</details>\n"
|
||||||
elif 'pr_files' in key.lower() and get_settings().pr_description.enable_semantic_files_types:
|
elif (
|
||||||
changes_walkthrough, pr_file_changes = self.process_pr_files_prediction(changes_walkthrough, value)
|
'pr_files' in key.lower()
|
||||||
|
and get_settings().pr_description.enable_semantic_files_types
|
||||||
|
):
|
||||||
|
changes_walkthrough, pr_file_changes = self.process_pr_files_prediction(
|
||||||
|
changes_walkthrough, value
|
||||||
|
)
|
||||||
changes_walkthrough = f"{PRDescriptionHeader.CHANGES_WALKTHROUGH.value}\n{changes_walkthrough}"
|
changes_walkthrough = f"{PRDescriptionHeader.CHANGES_WALKTHROUGH.value}\n{changes_walkthrough}"
|
||||||
elif key.lower().strip() == 'description':
|
elif key.lower().strip() == 'description':
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
value = ', '.join(v.rstrip() for v in value)
|
value = ', '.join(v.rstrip() for v in value)
|
||||||
value = value.replace('\n-', '\n\n-').strip() # makes the bullet points more readable by adding double space
|
value = value.replace(
|
||||||
|
'\n-', '\n\n-'
|
||||||
|
).strip() # makes the bullet points more readable by adding double space
|
||||||
pr_body += f"{value}\n"
|
pr_body += f"{value}\n"
|
||||||
else:
|
else:
|
||||||
# if the value is a list, join its items by comma
|
# if the value is a list, join its items by comma
|
||||||
@@ -591,24 +782,37 @@ class PRDescription:
             if idx < len(self.data) - 1:
                 pr_body += "\n\n___\n\n"

-        return title, pr_body, changes_walkthrough, pr_file_changes,
+        return (
+            title,
+            pr_body,
+            changes_walkthrough,
+            pr_file_changes,
+        )

     def _prepare_file_labels(self):
         file_label_dict = {}
-        if (not self.data or not isinstance(self.data, dict) or
-                'pr_files' not in self.data or not self.data['pr_files']):
+        if (
+            not self.data
+            or not isinstance(self.data, dict)
+            or 'pr_files' not in self.data
+            or not self.data['pr_files']
+        ):
             return file_label_dict
         for file in self.data['pr_files']:
             try:
                 required_fields = ['changes_title', 'filename', 'label']
                 if not all(field in file for field in required_fields):
                     # can happen for example if a YAML generation was interrupted in the middle (no more tokens)
-                    get_logger().warning(f"Missing required fields in file label dict {self.pr_id}, skipping file",
-                                         artifact={"file": file})
+                    get_logger().warning(
+                        f"Missing required fields in file label dict {self.pr_id}, skipping file",
+                        artifact={"file": file},
+                    )
                     continue
                 if not file.get('changes_title'):
-                    get_logger().warning(f"Empty changes title or summary in file label dict {self.pr_id}, skipping file",
-                                         artifact={"file": file})
+                    get_logger().warning(
+                        f"Empty changes title or summary in file label dict {self.pr_id}, skipping file",
+                        artifact={"file": file},
+                    )
                     continue
                 filename = file['filename'].replace("'", "`").replace('"', '`')
                 changes_summary = file.get('changes_summary', "").strip()
@@ -616,7 +820,9 @@ class PRDescription:
                 label = file.get('label').strip().lower()
                 if label not in file_label_dict:
                     file_label_dict[label] = []
-                file_label_dict[label].append((filename, changes_title, changes_summary))
+                file_label_dict[label].append(
+                    (filename, changes_title, changes_summary)
+                )
             except Exception as e:
                 get_logger().error(f"Error preparing file label dict {self.pr_id}: {e}")
                 pass
@ -640,7 +846,9 @@ class PRDescription:
|
|||||||
header = f"相关文件"
|
header = f"相关文件"
|
||||||
delta = 75
|
delta = 75
|
||||||
# header += " " * delta
|
# header += " " * delta
|
||||||
pr_body += f"""<thead><tr><th></th><th align="left">{header}</th></tr></thead>"""
|
pr_body += (
|
||||||
|
f"""<thead><tr><th></th><th align="left">{header}</th></tr></thead>"""
|
||||||
|
)
|
||||||
pr_body += """<tbody>"""
|
pr_body += """<tbody>"""
|
||||||
for semantic_label in value.keys():
|
for semantic_label in value.keys():
|
||||||
s_label = semantic_label.strip("'").strip('"')
|
s_label = semantic_label.strip("'").strip('"')
|
||||||
@ -651,14 +859,22 @@ class PRDescription:
|
|||||||
pr_body += f"""<td><details><summary>{len(list_tuples)} files</summary><table>"""
|
pr_body += f"""<td><details><summary>{len(list_tuples)} files</summary><table>"""
|
||||||
else:
|
else:
|
||||||
pr_body += f"""<td><table>"""
|
pr_body += f"""<td><table>"""
|
||||||
for filename, file_changes_title, file_change_description in list_tuples:
|
for (
|
||||||
|
filename,
|
||||||
|
file_changes_title,
|
||||||
|
file_change_description,
|
||||||
|
) in list_tuples:
|
||||||
filename = filename.replace("'", "`").rstrip()
|
filename = filename.replace("'", "`").rstrip()
|
||||||
filename_publish = filename.split("/")[-1]
|
filename_publish = filename.split("/")[-1]
|
||||||
if file_changes_title and file_changes_title.strip() != "...":
|
if file_changes_title and file_changes_title.strip() != "...":
|
||||||
file_changes_title_code = f"<code>{file_changes_title}</code>"
|
file_changes_title_code = f"<code>{file_changes_title}</code>"
|
||||||
file_changes_title_code_br = insert_br_after_x_chars(file_changes_title_code, x=(delta - 5)).strip()
|
file_changes_title_code_br = insert_br_after_x_chars(
|
||||||
|
file_changes_title_code, x=(delta - 5)
|
||||||
|
).strip()
|
||||||
if len(file_changes_title_code_br) < (delta - 5):
|
if len(file_changes_title_code_br) < (delta - 5):
|
||||||
file_changes_title_code_br += " " * ((delta - 5) - len(file_changes_title_code_br))
|
file_changes_title_code_br += " " * (
|
||||||
|
(delta - 5) - len(file_changes_title_code_br)
|
||||||
|
)
|
||||||
filename_publish = f"<strong>{filename_publish}</strong><dd>{file_changes_title_code_br}</dd>"
|
filename_publish = f"<strong>{filename_publish}</strong><dd>{file_changes_title_code_br}</dd>"
|
||||||
else:
|
else:
|
||||||
filename_publish = f"<strong>{filename_publish}</strong>"
|
filename_publish = f"<strong>{filename_publish}</strong>"
|
||||||
@ -679,15 +895,30 @@ class PRDescription:
|
|||||||
link = ""
|
link = ""
|
||||||
if hasattr(self.git_provider, 'get_line_link'):
|
if hasattr(self.git_provider, 'get_line_link'):
|
||||||
filename = filename.strip()
|
filename = filename.strip()
|
||||||
link = self.git_provider.get_line_link(filename, relevant_line_start=-1)
|
link = self.git_provider.get_line_link(
|
||||||
if (not link or not diff_plus_minus) and ('additional files' not in filename.lower()):
|
filename, relevant_line_start=-1
|
||||||
get_logger().warning(f"Error getting line link for '{filename}'")
|
)
|
||||||
|
if (not link or not diff_plus_minus) and (
|
||||||
|
'additional files' not in filename.lower()
|
||||||
|
):
|
||||||
|
get_logger().warning(
|
||||||
|
f"Error getting line link for '{filename}'"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Add file data to the PR body
|
# Add file data to the PR body
|
||||||
file_change_description_br = insert_br_after_x_chars(file_change_description, x=(delta - 5))
|
file_change_description_br = insert_br_after_x_chars(
|
||||||
pr_body = self.add_file_data(delta_nbsp, diff_plus_minus, file_change_description_br, filename,
|
file_change_description, x=(delta - 5)
|
||||||
filename_publish, link, pr_body)
|
)
|
||||||
|
pr_body = self.add_file_data(
|
||||||
|
delta_nbsp,
|
||||||
|
diff_plus_minus,
|
||||||
|
file_change_description_br,
|
||||||
|
filename,
|
||||||
|
filename_publish,
|
||||||
|
link,
|
||||||
|
pr_body,
|
||||||
|
)
|
||||||
|
|
||||||
# Close the collapsible file list
|
# Close the collapsible file list
|
||||||
if use_collapsible_file_list:
|
if use_collapsible_file_list:
|
||||||
@ -697,13 +928,22 @@ class PRDescription:
|
|||||||
pr_body += """</tr></tbody></table>"""
|
pr_body += """</tr></tbody></table>"""
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Error processing pr files to markdown {self.pr_id}: {str(e)}")
|
get_logger().error(
|
||||||
|
f"Error processing pr files to markdown {self.pr_id}: {str(e)}"
|
||||||
|
)
|
||||||
pass
|
pass
|
||||||
return pr_body, pr_comments
|
return pr_body, pr_comments
|
||||||
|
|
||||||
def add_file_data(self, delta_nbsp, diff_plus_minus, file_change_description_br, filename, filename_publish, link,
|
def add_file_data(
|
||||||
pr_body) -> str:
|
self,
|
||||||
|
delta_nbsp,
|
||||||
|
diff_plus_minus,
|
||||||
|
file_change_description_br,
|
||||||
|
filename,
|
||||||
|
filename_publish,
|
||||||
|
link,
|
||||||
|
pr_body,
|
||||||
|
) -> str:
|
||||||
if not file_change_description_br:
|
if not file_change_description_br:
|
||||||
pr_body += f"""
|
pr_body += f"""
|
||||||
<tr>
|
<tr>
|
||||||
@ -735,6 +975,7 @@ class PRDescription:
|
|||||||
"""
|
"""
|
||||||
return pr_body
|
return pr_body
|
||||||
|
|
||||||
|
|
||||||
def count_chars_without_html(string):
|
def count_chars_without_html(string):
|
||||||
if '<' not in string:
|
if '<' not in string:
|
||||||
return len(string)
|
return len(string)
|
||||||
@@ -16,8 +16,12 @@ from utils.pr_agent.log import get_logger


 class PRGenerateLabels:
-    def __init__(self, pr_url: str, args: list = None,
-                 ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
+    def __init__(
+        self,
+        pr_url: str,
+        args: list = None,
+        ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
+    ):
         """
         Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
         corresponding to the PR using an AI model.
@@ -93,7 +97,9 @@ class PRGenerateLabels:
             elif pr_labels:
                 value = ', '.join(v for v in pr_labels)
                 pr_labels_text = f"## PR Labels:\n{value}\n"
-                self.git_provider.publish_comment(pr_labels_text, is_temporary=False)
+                self.git_provider.publish_comment(
+                    pr_labels_text, is_temporary=False
+                )
             self.git_provider.remove_initial_comment()
         except Exception as e:
             get_logger().error(f"Error generating PR labels {self.pr_id}: {e}")
@@ -137,14 +143,18 @@ class PRGenerateLabels:
         set_custom_labels(variables, self.git_provider)
         self.variables = variables

-        system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(self.variables)
-        user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(self.variables)
+        system_prompt = environment.from_string(
+            get_settings().pr_custom_labels_prompt.system
+        ).render(self.variables)
+        user_prompt = environment.from_string(
+            get_settings().pr_custom_labels_prompt.user
+        ).render(self.variables)

         response, finish_reason = await self.ai_handler.chat_completion(
             model=model,
             temperature=get_settings().config.temperature,
             system=system_prompt,
-            user=user_prompt
+            user=user_prompt,
         )

         return response
@@ -153,8 +163,6 @@ class PRGenerateLabels:
         # Load the AI prediction data into a dictionary
         self.data = load_yaml(self.prediction.strip())

-
-
     def _prepare_labels(self) -> List[str]:
         pr_types = []

@@ -174,6 +182,8 @@ class PRGenerateLabels:
                 if label_i in d:
                     pr_types[i] = d[label_i]
             except Exception as e:
-                get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
+                get_logger().error(
+                    f"Error converting labels to original case {self.pr_id}: {e}"
+                )

         return pr_types
@@ -12,7 +12,11 @@ from utils.pr_agent.algo.pr_processing import retry_with_fallback_models
 from utils.pr_agent.algo.token_handler import TokenHandler
 from utils.pr_agent.algo.utils import ModelType, clip_tokens, load_yaml, get_max_tokens
 from utils.pr_agent.config_loader import get_settings
-from utils.pr_agent.git_providers import BitbucketServerProvider, GithubProvider, get_git_provider_with_context
+from utils.pr_agent.git_providers import (
+    BitbucketServerProvider,
+    GithubProvider,
+    get_git_provider_with_context,
+)
 from utils.pr_agent.log import get_logger


@@ -29,31 +33,50 @@ def extract_header(snippet):
         res = f"#{highest_header.lower().replace(' ', '-')}"
     return res


 class PRHelpMessage:
-    def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, return_as_string=False):
+    def __init__(
+        self,
+        pr_url: str,
+        args=None,
+        ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
+        return_as_string=False,
+    ):
         self.git_provider = get_git_provider_with_context(pr_url)
         self.ai_handler = ai_handler()
         self.question_str = self.parse_args(args)
         self.return_as_string = return_as_string
-        self.num_retrieved_snippets = get_settings().get('pr_help.num_retrieved_snippets', 5)
+        self.num_retrieved_snippets = get_settings().get(
+            'pr_help.num_retrieved_snippets', 5
+        )
         if self.question_str:
             self.vars = {
                 "question": self.question_str,
                 "snippets": "",
             }
-            self.token_handler = TokenHandler(None,
-                                              self.vars,
-                                              get_settings().pr_help_prompts.system,
-                                              get_settings().pr_help_prompts.user)
+            self.token_handler = TokenHandler(
+                None,
+                self.vars,
+                get_settings().pr_help_prompts.system,
+                get_settings().pr_help_prompts.user,
+            )

     async def _prepare_prediction(self, model: str):
         try:
             variables = copy.deepcopy(self.vars)
             environment = Environment(undefined=StrictUndefined)
-            system_prompt = environment.from_string(get_settings().pr_help_prompts.system).render(variables)
-            user_prompt = environment.from_string(get_settings().pr_help_prompts.user).render(variables)
+            system_prompt = environment.from_string(
+                get_settings().pr_help_prompts.system
+            ).render(variables)
+            user_prompt = environment.from_string(
+                get_settings().pr_help_prompts.user
+            ).render(variables)
             response, finish_reason = await self.ai_handler.chat_completion(
-                model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
+                model=model,
+                temperature=get_settings().config.temperature,
+                system=system_prompt,
+                user=user_prompt,
+            )
             return response
         except Exception as e:
             get_logger().error(f"Error while preparing prediction: {e}")
@ -81,7 +104,7 @@ class PRHelpMessage:
|
|||||||
'.': '',
|
'.': '',
|
||||||
'?': '',
|
'?': '',
|
||||||
'!': '',
|
'!': '',
|
||||||
' ': '-'
|
' ': '-',
|
||||||
}
|
}
|
||||||
|
|
||||||
# Compile regex pattern for characters to remove
|
# Compile regex pattern for characters to remove
|
||||||
@ -90,37 +113,69 @@ class PRHelpMessage:
|
|||||||
# Perform replacements in a single pass and convert to lowercase
|
# Perform replacements in a single pass and convert to lowercase
|
||||||
return pattern.sub(lambda m: replacements[m.group()], cleaned).lower()
|
return pattern.sub(lambda m: replacements[m.group()], cleaned).lower()
|
||||||
except Exception:
|
except Exception:
|
||||||
get_logger().exception(f"Error while formatting markdown header", artifacts={'header': header})
|
get_logger().exception(
|
||||||
|
f"Error while formatting markdown header", artifacts={'header': header}
|
||||||
|
)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
try:
|
try:
|
||||||
if self.question_str:
|
if self.question_str:
|
||||||
get_logger().info(f'Answering a PR question about the PR {self.git_provider.pr_url} ')
|
get_logger().info(
|
||||||
|
f'Answering a PR question about the PR {self.git_provider.pr_url} '
|
||||||
|
)
|
||||||
|
|
||||||
if not get_settings().get('openai.key'):
|
if not get_settings().get('openai.key'):
|
||||||
if get_settings().config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
self.git_provider.publish_comment(
|
self.git_provider.publish_comment(
|
||||||
"The `Help` tool chat feature requires an OpenAI API key for calculating embeddings")
|
"The `Help` tool chat feature requires an OpenAI API key for calculating embeddings"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
get_logger().error("The `Help` tool chat feature requires an OpenAI API key for calculating embeddings")
|
get_logger().error(
|
||||||
|
"The `Help` tool chat feature requires an OpenAI API key for calculating embeddings"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# current path
|
# current path
|
||||||
docs_path= Path(__file__).parent.parent.parent / 'docs' / 'docs'
|
docs_path = Path(__file__).parent.parent.parent / 'docs' / 'docs'
|
||||||
# get all the 'md' files inside docs_path and its subdirectories
|
# get all the 'md' files inside docs_path and its subdirectories
|
||||||
md_files = list(docs_path.glob('**/*.md'))
|
md_files = list(docs_path.glob('**/*.md'))
|
||||||
folders_to_exclude = ['/finetuning_benchmark/']
|
folders_to_exclude = ['/finetuning_benchmark/']
|
||||||
files_to_exclude = {'EXAMPLE_BEST_PRACTICE.md', 'compression_strategy.md', '/docs/overview/index.md'}
|
files_to_exclude = {
|
||||||
md_files = [file for file in md_files if not any(folder in str(file) for folder in folders_to_exclude) and not any(file.name == file_to_exclude for file_to_exclude in files_to_exclude)]
|
'EXAMPLE_BEST_PRACTICE.md',
|
||||||
|
'compression_strategy.md',
|
||||||
|
'/docs/overview/index.md',
|
||||||
|
}
|
||||||
|
md_files = [
|
||||||
|
file
|
||||||
|
for file in md_files
|
||||||
|
if not any(folder in str(file) for folder in folders_to_exclude)
|
||||||
|
and not any(
|
||||||
|
file.name == file_to_exclude
|
||||||
|
for file_to_exclude in files_to_exclude
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
# sort the 'md_files' so that 'priority_files' will be at the top
|
# sort the 'md_files' so that 'priority_files' will be at the top
|
||||||
priority_files_strings = ['/docs/index.md', '/usage-guide', 'tools/describe.md', 'tools/review.md',
|
priority_files_strings = [
|
||||||
'tools/improve.md', '/faq']
|
'/docs/index.md',
|
||||||
md_files_priority = [file for file in md_files if
|
'/usage-guide',
|
||||||
any(priority_string in str(file) for priority_string in priority_files_strings)]
|
'tools/describe.md',
|
||||||
md_files_not_priority = [file for file in md_files if file not in md_files_priority]
|
'tools/review.md',
|
||||||
|
'tools/improve.md',
|
||||||
|
'/faq',
|
||||||
|
]
|
||||||
|
md_files_priority = [
|
||||||
|
file
|
||||||
|
for file in md_files
|
||||||
|
if any(
|
||||||
|
priority_string in str(file)
|
||||||
|
for priority_string in priority_files_strings
|
||||||
|
)
|
||||||
|
]
|
||||||
|
md_files_not_priority = [
|
||||||
|
file for file in md_files if file not in md_files_priority
|
||||||
|
]
|
||||||
md_files = md_files_priority + md_files_not_priority
|
md_files = md_files_priority + md_files_not_priority
|
||||||
|
|
||||||
docs_prompt = ""
|
docs_prompt = ""
|
||||||
@@ -132,24 +187,36 @@ class PRHelpMessage:
                    except Exception as e:
                        get_logger().error(f"Error while reading the file {file}: {e}")
                token_count = self.token_handler.count_tokens(docs_prompt)
-                get_logger().debug(f"Token count of full documentation website: {token_count}")
+                get_logger().debug(
+                    f"Token count of full documentation website: {token_count}"
+                )

                model = get_settings().config.model
                if model in MAX_TOKENS:
-                    max_tokens_full = MAX_TOKENS[model]  # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
+                    max_tokens_full = MAX_TOKENS[
+                        model
+                    ]  # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
                else:
                    max_tokens_full = get_max_tokens(model)
                delta_output = 2000
                if token_count > max_tokens_full - delta_output:
-                    get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message.")
-                    docs_prompt = clip_tokens(docs_prompt, max_tokens_full - delta_output)
+                    get_logger().info(
+                        f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message."
+                    )
+                    docs_prompt = clip_tokens(
+                        docs_prompt, max_tokens_full - delta_output
+                    )
                self.vars['snippets'] = docs_prompt.strip()

                # run the AI model
-                response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
+                response = await retry_with_fallback_models(
+                    self._prepare_prediction, model_type=ModelType.REGULAR
+                )
                response_yaml = load_yaml(response)
                if isinstance(response_yaml, str):
-                    get_logger().warning(f"failing to parse response: {response_yaml}, publishing the response as is")
+                    get_logger().warning(
+                        f"failing to parse response: {response_yaml}, publishing the response as is"
+                    )
                    if get_settings().config.publish_output:
                        answer_str = f"### Question: \n{self.question_str}\n\n"
                        answer_str += f"### Answer:\n\n"
@ -160,7 +227,9 @@ class PRHelpMessage:
|
|||||||
relevant_sections = response_yaml.get('relevant_sections')
|
relevant_sections = response_yaml.get('relevant_sections')
|
||||||
|
|
||||||
if not relevant_sections:
|
if not relevant_sections:
|
||||||
get_logger().info(f"Could not find relevant answer for the question: {self.question_str}")
|
get_logger().info(
|
||||||
|
f"Could not find relevant answer for the question: {self.question_str}"
|
||||||
|
)
|
||||||
if get_settings().config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
answer_str = f"### Question: \n{self.question_str}\n\n"
|
answer_str = f"### Question: \n{self.question_str}\n\n"
|
||||||
answer_str += f"### Answer:\n\n"
|
answer_str += f"### Answer:\n\n"
|
||||||
@ -178,29 +247,38 @@ class PRHelpMessage:
|
|||||||
for section in relevant_sections:
|
for section in relevant_sections:
|
||||||
file = section.get('file_name').strip().removesuffix('.md')
|
file = section.get('file_name').strip().removesuffix('.md')
|
||||||
if str(section['relevant_section_header_string']).strip():
|
if str(section['relevant_section_header_string']).strip():
|
||||||
markdown_header = self.format_markdown_header(section['relevant_section_header_string'])
|
markdown_header = self.format_markdown_header(
|
||||||
|
section['relevant_section_header_string']
|
||||||
|
)
|
||||||
answer_str += f"> - {base_path}{file}#{markdown_header}\n"
|
answer_str += f"> - {base_path}{file}#{markdown_header}\n"
|
||||||
else:
|
else:
|
||||||
answer_str += f"> - {base_path}{file}\n"
|
answer_str += f"> - {base_path}{file}\n"
|
||||||
|
|
||||||
|
|
||||||
# publish the answer
|
# publish the answer
|
||||||
if get_settings().config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
self.git_provider.publish_comment(answer_str)
|
self.git_provider.publish_comment(answer_str)
|
||||||
else:
|
else:
|
||||||
get_logger().info(f"Answer:\n{answer_str}")
|
get_logger().info(f"Answer:\n{answer_str}")
|
||||||
else:
|
else:
|
||||||
if not isinstance(self.git_provider, BitbucketServerProvider) and not self.git_provider.is_supported("gfm_markdown"):
|
if not isinstance(
|
||||||
|
self.git_provider, BitbucketServerProvider
|
||||||
|
) and not self.git_provider.is_supported("gfm_markdown"):
|
||||||
self.git_provider.publish_comment(
|
self.git_provider.publish_comment(
|
||||||
"The `Help` tool requires gfm markdown, which is not supported by your code platform.")
|
"The `Help` tool requires gfm markdown, which is not supported by your code platform."
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
get_logger().info('Getting PR Help Message...')
|
get_logger().info('Getting PR Help Message...')
|
||||||
relevant_configs = {'pr_help': dict(get_settings().pr_help),
|
relevant_configs = {
|
||||||
'config': dict(get_settings().config)}
|
'pr_help': dict(get_settings().pr_help),
|
||||||
|
'config': dict(get_settings().config),
|
||||||
|
}
|
||||||
get_logger().debug("Relevant configs", artifacts=relevant_configs)
|
get_logger().debug("Relevant configs", artifacts=relevant_configs)
|
||||||
pr_comment = "## PR Agent Walkthrough 🤖\n\n"
|
pr_comment = "## PR Agent Walkthrough 🤖\n\n"
|
||||||
pr_comment += "Welcome to the PR Agent, an AI-powered tool for automated pull request analysis, feedback, suggestions and more."""
|
pr_comment += (
|
||||||
|
"Welcome to the PR Agent, an AI-powered tool for automated pull request analysis, feedback, suggestions and more."
|
||||||
|
""
|
||||||
|
)
|
||||||
pr_comment += "\n\nHere is a list of tools you can use to interact with the PR Agent:\n"
|
pr_comment += "\n\nHere is a list of tools you can use to interact with the PR Agent:\n"
|
||||||
base_path = "https://pr-agent-docs.codium.ai/tools"
|
base_path = "https://pr-agent-docs.codium.ai/tools"
|
||||||
|
|
||||||
@ -211,32 +289,58 @@ class PRHelpMessage:
|
|||||||
tool_names.append(f"[UPDATE CHANGELOG]({base_path}/update_changelog/)")
|
tool_names.append(f"[UPDATE CHANGELOG]({base_path}/update_changelog/)")
|
||||||
tool_names.append(f"[ADD DOCS]({base_path}/documentation/) 💎")
|
tool_names.append(f"[ADD DOCS]({base_path}/documentation/) 💎")
|
||||||
tool_names.append(f"[TEST]({base_path}/test/) 💎")
|
tool_names.append(f"[TEST]({base_path}/test/) 💎")
|
||||||
tool_names.append(f"[IMPROVE COMPONENT]({base_path}/improve_component/) 💎")
|
tool_names.append(
|
||||||
|
f"[IMPROVE COMPONENT]({base_path}/improve_component/) 💎"
|
||||||
|
)
|
||||||
tool_names.append(f"[ANALYZE]({base_path}/analyze/) 💎")
|
tool_names.append(f"[ANALYZE]({base_path}/analyze/) 💎")
|
||||||
tool_names.append(f"[ASK]({base_path}/ask/)")
|
tool_names.append(f"[ASK]({base_path}/ask/)")
|
||||||
tool_names.append(f"[SIMILAR ISSUE]({base_path}/similar_issues/)")
|
tool_names.append(f"[SIMILAR ISSUE]({base_path}/similar_issues/)")
|
||||||
tool_names.append(f"[GENERATE CUSTOM LABELS]({base_path}/custom_labels/) 💎")
|
tool_names.append(
|
||||||
|
f"[GENERATE CUSTOM LABELS]({base_path}/custom_labels/) 💎"
|
||||||
|
)
|
||||||
tool_names.append(f"[CI FEEDBACK]({base_path}/ci_feedback/) 💎")
|
tool_names.append(f"[CI FEEDBACK]({base_path}/ci_feedback/) 💎")
|
||||||
tool_names.append(f"[CUSTOM PROMPT]({base_path}/custom_prompt/) 💎")
|
tool_names.append(f"[CUSTOM PROMPT]({base_path}/custom_prompt/) 💎")
|
||||||
tool_names.append(f"[IMPLEMENT]({base_path}/implement/) 💎")
|
tool_names.append(f"[IMPLEMENT]({base_path}/implement/) 💎")
|
||||||
|
|
||||||
descriptions = []
|
descriptions = []
|
||||||
descriptions.append("Generates PR description - title, type, summary, code walkthrough and labels")
|
descriptions.append(
|
||||||
descriptions.append("Adjustable feedback about the PR, possible issues, security concerns, review effort and more")
|
"Generates PR description - title, type, summary, code walkthrough and labels"
|
||||||
|
)
|
||||||
|
descriptions.append(
|
||||||
|
"Adjustable feedback about the PR, possible issues, security concerns, review effort and more"
|
||||||
|
)
|
||||||
descriptions.append("Code suggestions for improving the PR")
|
descriptions.append("Code suggestions for improving the PR")
|
||||||
descriptions.append("Automatically updates the changelog")
|
descriptions.append("Automatically updates the changelog")
|
||||||
descriptions.append("Generates documentation to methods/functions/classes that changed in the PR")
|
descriptions.append(
|
||||||
descriptions.append("Generates unit tests for a specific component, based on the PR code change")
|
"Generates documentation to methods/functions/classes that changed in the PR"
|
||||||
descriptions.append("Code suggestions for a specific component that changed in the PR")
|
)
|
||||||
descriptions.append("Identifies code components that changed in the PR, and enables to interactively generate tests, docs, and code suggestions for each component")
|
descriptions.append(
|
||||||
|
"Generates unit tests for a specific component, based on the PR code change"
|
||||||
|
)
|
||||||
|
descriptions.append(
|
||||||
|
"Code suggestions for a specific component that changed in the PR"
|
||||||
|
)
|
||||||
|
descriptions.append(
|
||||||
|
"Identifies code components that changed in the PR, and enables to interactively generate tests, docs, and code suggestions for each component"
|
||||||
|
)
|
||||||
descriptions.append("Answering free-text questions about the PR")
|
descriptions.append("Answering free-text questions about the PR")
|
||||||
descriptions.append("Automatically retrieves and presents similar issues")
|
descriptions.append(
|
||||||
descriptions.append("Generates custom labels for the PR, based on specific guidelines defined by the user")
|
"Automatically retrieves and presents similar issues"
|
||||||
descriptions.append("Generates feedback and analysis for a failed CI job")
|
)
|
||||||
descriptions.append("Generates custom suggestions for improving the PR code, derived only from a specific guidelines prompt defined by the user")
|
descriptions.append(
|
||||||
descriptions.append("Generates implementation code from review suggestions")
|
"Generates custom labels for the PR, based on specific guidelines defined by the user"
|
||||||
|
)
|
||||||
|
descriptions.append(
|
||||||
|
"Generates feedback and analysis for a failed CI job"
|
||||||
|
)
|
||||||
|
descriptions.append(
|
||||||
|
"Generates custom suggestions for improving the PR code, derived only from a specific guidelines prompt defined by the user"
|
||||||
|
)
|
||||||
|
descriptions.append(
|
||||||
|
"Generates implementation code from review suggestions"
|
||||||
|
)
|
||||||
|
|
||||||
commands =[]
|
commands = []
|
||||||
commands.append("`/describe`")
|
commands.append("`/describe`")
|
||||||
commands.append("`/review`")
|
commands.append("`/review`")
|
||||||
commands.append("`/improve`")
|
commands.append("`/improve`")
|
||||||
@ -271,7 +375,9 @@ class PRHelpMessage:
|
|||||||
checkbox_list.append("[*]")
|
checkbox_list.append("[*]")
|
||||||
checkbox_list.append("[*]")
|
checkbox_list.append("[*]")
|
||||||
|
|
||||||
if isinstance(self.git_provider, GithubProvider) and not get_settings().config.get('disable_checkboxes', False):
|
if isinstance(
|
||||||
|
self.git_provider, GithubProvider
|
||||||
|
) and not get_settings().config.get('disable_checkboxes', False):
|
||||||
pr_comment += f"<table><tr align='left'><th align='left'>Tool</th><th align='left'>Description</th><th align='left'>Trigger Interactively :gem:</th></tr>"
|
pr_comment += f"<table><tr align='left'><th align='left'>Tool</th><th align='left'>Description</th><th align='left'>Trigger Interactively :gem:</th></tr>"
|
||||||
for i in range(len(tool_names)):
|
for i in range(len(tool_names)):
|
||||||
pr_comment += f"\n<tr><td align='left'>\n\n<strong>{tool_names[i]}</strong></td>\n<td>{descriptions[i]}</td>\n<td>\n\n{checkbox_list[i]}\n</td></tr>"
|
pr_comment += f"\n<tr><td align='left'>\n\n<strong>{tool_names[i]}</strong></td>\n<td>{descriptions[i]}</td>\n<td>\n\n{checkbox_list[i]}\n</td></tr>"
|
||||||
@@ -5,8 +5,7 @@ from jinja2 import Environment, StrictUndefined

 from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
 from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
-from utils.pr_agent.algo.git_patch_processing import (
-    extract_hunk_lines_from_patch)
+from utils.pr_agent.algo.git_patch_processing import extract_hunk_lines_from_patch
 from utils.pr_agent.algo.pr_processing import retry_with_fallback_models
 from utils.pr_agent.algo.token_handler import TokenHandler
 from utils.pr_agent.algo.utils import ModelType
@@ -17,7 +16,12 @@ from utils.pr_agent.log import get_logger


 class PR_LineQuestions:
-    def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
+    def __init__(
+        self,
+        pr_url: str,
+        args=None,
+        ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
+    ):
         self.question_str = self.parse_args(args)
         self.git_provider = get_git_provider()(pr_url)
         self.main_pr_language = get_main_pr_language(
@@ -34,10 +38,12 @@ class PR_LineQuestions:
             "full_hunk": "",
             "selected_lines": "",
         }
-        self.token_handler = TokenHandler(self.git_provider.pr,
-                                          self.vars,
-                                          get_settings().pr_line_questions_prompt.system,
-                                          get_settings().pr_line_questions_prompt.user)
+        self.token_handler = TokenHandler(
+            self.git_provider.pr,
+            self.vars,
+            get_settings().pr_line_questions_prompt.system,
+            get_settings().pr_line_questions_prompt.user,
+        )
         self.patches_diff = None
         self.prediction = None

@ -48,7 +54,6 @@ class PR_LineQuestions:
|
|||||||
question_str = ""
|
question_str = ""
|
||||||
return question_str
|
return question_str
|
||||||
|
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
get_logger().info('Answering a PR lines question...')
|
get_logger().info('Answering a PR lines question...')
|
||||||
# if get_settings().config.publish_output:
|
# if get_settings().config.publish_output:
|
||||||
@ -62,22 +67,27 @@ class PR_LineQuestions:
|
|||||||
file_name = get_settings().get('file_name', '')
|
file_name = get_settings().get('file_name', '')
|
||||||
comment_id = get_settings().get('comment_id', '')
|
comment_id = get_settings().get('comment_id', '')
|
||||||
if ask_diff:
|
if ask_diff:
|
||||||
self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch(ask_diff,
|
self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch(
|
||||||
file_name,
|
ask_diff, file_name, line_start=line_start, line_end=line_end, side=side
|
||||||
line_start=line_start,
|
)
|
||||||
line_end=line_end,
|
|
||||||
side=side
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
diff_files = self.git_provider.get_diff_files()
|
diff_files = self.git_provider.get_diff_files()
|
||||||
for file in diff_files:
|
for file in diff_files:
|
||||||
if file.filename == file_name:
|
if file.filename == file_name:
|
||||||
self.patch_with_lines, self.selected_lines = extract_hunk_lines_from_patch(file.patch, file.filename,
|
(
|
||||||
line_start=line_start,
|
self.patch_with_lines,
|
||||||
line_end=line_end,
|
self.selected_lines,
|
||||||
side=side)
|
) = extract_hunk_lines_from_patch(
|
||||||
|
file.patch,
|
||||||
|
file.filename,
|
||||||
|
line_start=line_start,
|
||||||
|
line_end=line_end,
|
||||||
|
side=side,
|
||||||
|
)
|
||||||
if self.patch_with_lines:
|
if self.patch_with_lines:
|
||||||
model_answer = await retry_with_fallback_models(self._get_prediction, model_type=ModelType.WEAK)
|
model_answer = await retry_with_fallback_models(
|
||||||
|
self._get_prediction, model_type=ModelType.WEAK
|
||||||
|
)
|
||||||
# sanitize the answer so that no line will start with "/"
|
# sanitize the answer so that no line will start with "/"
|
||||||
model_answer_sanitized = model_answer.strip().replace("\n/", "\n /")
|
model_answer_sanitized = model_answer.strip().replace("\n/", "\n /")
|
||||||
if model_answer_sanitized.startswith("/"):
|
if model_answer_sanitized.startswith("/"):
|
||||||
@ -85,7 +95,9 @@ class PR_LineQuestions:
|
|||||||
|
|
||||||
get_logger().info('Preparing answer...')
|
get_logger().info('Preparing answer...')
|
||||||
if comment_id:
|
if comment_id:
|
||||||
self.git_provider.reply_to_comment_from_comment_id(comment_id, model_answer_sanitized)
|
self.git_provider.reply_to_comment_from_comment_id(
|
||||||
|
comment_id, model_answer_sanitized
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.git_provider.publish_comment(model_answer_sanitized)
|
self.git_provider.publish_comment(model_answer_sanitized)
|
||||||
|
|
||||||
@ -96,8 +108,12 @@ class PR_LineQuestions:
|
|||||||
variables["full_hunk"] = self.patch_with_lines # update diff
|
variables["full_hunk"] = self.patch_with_lines # update diff
|
||||||
variables["selected_lines"] = self.selected_lines
|
variables["selected_lines"] = self.selected_lines
|
||||||
environment = Environment(undefined=StrictUndefined)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
system_prompt = environment.from_string(get_settings().pr_line_questions_prompt.system).render(variables)
|
system_prompt = environment.from_string(
|
||||||
user_prompt = environment.from_string(get_settings().pr_line_questions_prompt.user).render(variables)
|
get_settings().pr_line_questions_prompt.system
|
||||||
|
).render(variables)
|
||||||
|
user_prompt = environment.from_string(
|
||||||
|
get_settings().pr_line_questions_prompt.user
|
||||||
|
).render(variables)
|
||||||
if get_settings().config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
# get_logger().info(f"\nSystem prompt:\n{system_prompt}")
|
# get_logger().info(f"\nSystem prompt:\n{system_prompt}")
|
||||||
# get_logger().info(f"\nUser prompt:\n{user_prompt}")
|
# get_logger().info(f"\nUser prompt:\n{user_prompt}")
|
||||||
@ -105,5 +121,9 @@ class PR_LineQuestions:
|
|||||||
print(f"\nUser prompt:\n{user_prompt}")
|
print(f"\nUser prompt:\n{user_prompt}")
|
||||||
|
|
||||||
response, finish_reason = await self.ai_handler.chat_completion(
|
response, finish_reason = await self.ai_handler.chat_completion(
|
||||||
model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
|
model=model,
|
||||||
|
temperature=get_settings().config.temperature,
|
||||||
|
system=system_prompt,
|
||||||
|
user=user_prompt,
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
@@ -16,7 +16,12 @@ from utils.pr_agent.servers.help import HelpMessage


 class PRQuestions:
-    def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
+    def __init__(
+        self,
+        pr_url: str,
+        args=None,
+        ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
+    ):
         question_str = self.parse_args(args)
         self.pr_url = pr_url
         self.git_provider = get_git_provider()(pr_url)
@@ -36,10 +41,12 @@ class PRQuestions:
             "questions": self.question_str,
             "commit_messages_str": self.git_provider.get_commit_messages(),
         }
-        self.token_handler = TokenHandler(self.git_provider.pr,
-                                          self.vars,
-                                          get_settings().pr_questions_prompt.system,
-                                          get_settings().pr_questions_prompt.user)
+        self.token_handler = TokenHandler(
+            self.git_provider.pr,
+            self.vars,
+            get_settings().pr_questions_prompt.system,
+            get_settings().pr_questions_prompt.user,
+        )
         self.patches_diff = None
         self.prediction = None

@@ -52,8 +59,10 @@ class PRQuestions:

     async def run(self):
         get_logger().info(f'Answering a PR question about the PR {self.pr_url} ')
-        relevant_configs = {'pr_questions': dict(get_settings().pr_questions),
-                            'config': dict(get_settings().config)}
+        relevant_configs = {
+            'pr_questions': dict(get_settings().pr_questions),
+            'config': dict(get_settings().config),
+        }
         get_logger().debug("Relevant configs", artifacts=relevant_configs)
         if get_settings().config.publish_output:
             self.git_provider.publish_comment("思考回答中...", is_temporary=True)
@ -63,12 +72,17 @@ class PRQuestions:
|
|||||||
if img_path:
|
if img_path:
|
||||||
get_logger().debug(f"Image path identified", artifact=img_path)
|
get_logger().debug(f"Image path identified", artifact=img_path)
|
||||||
|
|
||||||
await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK)
|
await retry_with_fallback_models(
|
||||||
|
self._prepare_prediction, model_type=ModelType.WEAK
|
||||||
|
)
|
||||||
|
|
||||||
pr_comment = self._prepare_pr_answer()
|
pr_comment = self._prepare_pr_answer()
|
||||||
get_logger().debug(f"PR output", artifact=pr_comment)
|
get_logger().debug(f"PR output", artifact=pr_comment)
|
||||||
|
|
||||||
if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_questions.enable_help_text:
|
if (
|
||||||
|
self.git_provider.is_supported("gfm_markdown")
|
||||||
|
and get_settings().pr_questions.enable_help_text
|
||||||
|
):
|
||||||
pr_comment += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
|
pr_comment += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
|
||||||
pr_comment += HelpMessage.get_ask_usage_guide()
|
pr_comment += HelpMessage.get_ask_usage_guide()
|
||||||
pr_comment += "\n</details>\n"
|
pr_comment += "\n</details>\n"
|
||||||
@ -85,7 +99,9 @@ class PRQuestions:
|
|||||||
# /ask question ... > 
|
# /ask question ... > 
|
||||||
img_path = self.question_str.split('![image]')[1].strip().strip('()')
|
img_path = self.question_str.split('![image]')[1].strip().strip('()')
|
||||||
self.vars['img_path'] = img_path
|
self.vars['img_path'] = img_path
|
||||||
elif 'https://' in self.question_str and ('.png' in self.question_str or 'jpg' in self.question_str): # direct image link
|
elif 'https://' in self.question_str and (
|
||||||
|
'.png' in self.question_str or 'jpg' in self.question_str
|
||||||
|
): # direct image link
|
||||||
# include https:// in the image path
|
# include https:// in the image path
|
||||||
img_path = 'https://' + self.question_str.split('https://')[1]
|
img_path = 'https://' + self.question_str.split('https://')[1]
|
||||||
self.vars['img_path'] = img_path
|
self.vars['img_path'] = img_path
|
||||||
@ -104,16 +120,28 @@ class PRQuestions:
|
|||||||
variables = copy.deepcopy(self.vars)
|
variables = copy.deepcopy(self.vars)
|
||||||
variables["diff"] = self.patches_diff # update diff
|
variables["diff"] = self.patches_diff # update diff
|
||||||
environment = Environment(undefined=StrictUndefined)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
system_prompt = environment.from_string(get_settings().pr_questions_prompt.system).render(variables)
|
system_prompt = environment.from_string(
|
||||||
user_prompt = environment.from_string(get_settings().pr_questions_prompt.user).render(variables)
|
get_settings().pr_questions_prompt.system
|
||||||
|
).render(variables)
|
||||||
|
user_prompt = environment.from_string(
|
||||||
|
get_settings().pr_questions_prompt.user
|
||||||
|
).render(variables)
|
||||||
if 'img_path' in variables:
|
if 'img_path' in variables:
|
||||||
img_path = self.vars['img_path']
|
img_path = self.vars['img_path']
|
||||||
response, finish_reason = await (self.ai_handler.chat_completion
|
response, finish_reason = await self.ai_handler.chat_completion(
|
||||||
(model=model, temperature=get_settings().config.temperature,
|
model=model,
|
||||||
system=system_prompt, user=user_prompt, img_path=img_path))
|
temperature=get_settings().config.temperature,
|
||||||
|
system=system_prompt,
|
||||||
|
user=user_prompt,
|
||||||
|
img_path=img_path,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
response, finish_reason = await self.ai_handler.chat_completion(
|
response, finish_reason = await self.ai_handler.chat_completion(
|
||||||
model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
|
model=model,
|
||||||
|
temperature=get_settings().config.temperature,
|
||||||
|
system=system_prompt,
|
||||||
|
user=user_prompt,
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _prepare_pr_answer(self) -> str:
|
def _prepare_pr_answer(self) -> str:
|
||||||
@ -123,9 +151,13 @@ class PRQuestions:
|
|||||||
if model_answer_sanitized.startswith("/"):
|
if model_answer_sanitized.startswith("/"):
|
||||||
model_answer_sanitized = " " + model_answer_sanitized
|
model_answer_sanitized = " " + model_answer_sanitized
|
||||||
if model_answer_sanitized != model_answer:
|
if model_answer_sanitized != model_answer:
|
||||||
get_logger().debug(f"Sanitized model answer",
|
get_logger().debug(
|
||||||
artifact={"model_answer": model_answer, "sanitized_answer": model_answer_sanitized})
|
f"Sanitized model answer",
|
||||||
|
artifact={
|
||||||
|
"model_answer": model_answer,
|
||||||
|
"sanitized_answer": model_answer_sanitized,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
answer_str = f"### **Ask**❓\n{self.question_str}\n\n"
|
answer_str = f"### **Ask**❓\n{self.question_str}\n\n"
|
||||||
answer_str += f"### **Answer:**\n{model_answer_sanitized}\n\n"
|
answer_str += f"### **Answer:**\n{model_answer_sanitized}\n\n"
|
||||||
|
|||||||
@ -7,21 +7,29 @@ from jinja2 import Environment, StrictUndefined
|
|||||||
|
|
||||||
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
from utils.pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
|
from utils.pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
|
||||||
from utils.pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files,
|
from utils.pr_agent.algo.pr_processing import (
|
||||||
get_pr_diff,
|
add_ai_metadata_to_diff_files,
|
||||||
retry_with_fallback_models)
|
get_pr_diff,
|
||||||
|
retry_with_fallback_models,
|
||||||
|
)
|
||||||
from utils.pr_agent.algo.token_handler import TokenHandler
|
from utils.pr_agent.algo.token_handler import TokenHandler
|
||||||
from utils.pr_agent.algo.utils import (ModelType, PRReviewHeader,
|
from utils.pr_agent.algo.utils import (
|
||||||
convert_to_markdown_v2, github_action_output,
|
ModelType,
|
||||||
load_yaml, show_relevant_configurations)
|
PRReviewHeader,
|
||||||
|
convert_to_markdown_v2,
|
||||||
|
github_action_output,
|
||||||
|
load_yaml,
|
||||||
|
show_relevant_configurations,
|
||||||
|
)
|
||||||
from utils.pr_agent.config_loader import get_settings
|
from utils.pr_agent.config_loader import get_settings
|
||||||
from utils.pr_agent.git_providers import (get_git_provider_with_context)
|
from utils.pr_agent.git_providers import get_git_provider_with_context
|
||||||
from utils.pr_agent.git_providers.git_provider import (IncrementalPR,
|
from utils.pr_agent.git_providers.git_provider import (
|
||||||
get_main_pr_language)
|
IncrementalPR,
|
||||||
|
get_main_pr_language,
|
||||||
|
)
|
||||||
from utils.pr_agent.log import get_logger
|
from utils.pr_agent.log import get_logger
|
||||||
from utils.pr_agent.servers.help import HelpMessage
|
from utils.pr_agent.servers.help import HelpMessage
|
||||||
from utils.pr_agent.tools.ticket_pr_compliance_check import (
|
from utils.pr_agent.tools.ticket_pr_compliance_check import extract_and_cache_pr_tickets
|
||||||
extract_and_cache_pr_tickets)
|
|
||||||
|
|
||||||
|
|
||||||
class PRReviewer:
|
class PRReviewer:
|
||||||
@ -29,8 +37,14 @@ class PRReviewer:
|
|||||||
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
|
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
|
def __init__(
|
||||||
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
self,
|
||||||
|
pr_url: str,
|
||||||
|
is_answer: bool = False,
|
||||||
|
is_auto: bool = False,
|
||||||
|
args: list = None,
|
||||||
|
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
|
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
|
||||||
|
|
||||||
@ -55,16 +69,23 @@ class PRReviewer:
|
|||||||
self.is_auto = is_auto
|
self.is_auto = is_auto
|
||||||
|
|
||||||
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
|
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
|
||||||
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
|
raise Exception(
|
||||||
|
f"Answer mode is not supported for {get_settings().config.git_provider} for now"
|
||||||
|
)
|
||||||
self.ai_handler = ai_handler()
|
self.ai_handler = ai_handler()
|
||||||
self.ai_handler.main_pr_language = self.main_language
|
self.ai_handler.main_pr_language = self.main_language
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
answer_str, question_str = self._get_user_answers()
|
answer_str, question_str = self._get_user_answers()
|
||||||
self.pr_description, self.pr_description_files = (
|
(
|
||||||
self.git_provider.get_pr_description(split_changes_walkthrough=True))
|
self.pr_description,
|
||||||
if (self.pr_description_files and get_settings().get("config.is_auto_command", False) and
|
self.pr_description_files,
|
||||||
get_settings().get("config.enable_ai_metadata", False)):
|
) = self.git_provider.get_pr_description(split_changes_walkthrough=True)
|
||||||
|
if (
|
||||||
|
self.pr_description_files
|
||||||
|
and get_settings().get("config.is_auto_command", False)
|
||||||
|
and get_settings().get("config.enable_ai_metadata", False)
|
||||||
|
):
|
||||||
add_ai_metadata_to_diff_files(self.git_provider, self.pr_description_files)
|
add_ai_metadata_to_diff_files(self.git_provider, self.pr_description_files)
|
||||||
get_logger().debug(f"AI metadata added to the this command")
|
get_logger().debug(f"AI metadata added to the this command")
|
||||||
else:
|
else:
|
||||||
@ -89,9 +110,11 @@ class PRReviewer:
|
|||||||
"commit_messages_str": self.git_provider.get_commit_messages(),
|
"commit_messages_str": self.git_provider.get_commit_messages(),
|
||||||
"custom_labels": "",
|
"custom_labels": "",
|
||||||
"enable_custom_labels": get_settings().config.enable_custom_labels,
|
"enable_custom_labels": get_settings().config.enable_custom_labels,
|
||||||
"is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
|
"is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
|
||||||
"related_tickets": get_settings().get('related_tickets', []),
|
"related_tickets": get_settings().get('related_tickets', []),
|
||||||
'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False),
|
'duplicate_prompt_examples': get_settings().config.get(
|
||||||
|
'duplicate_prompt_examples', False
|
||||||
|
),
|
||||||
"date": datetime.datetime.now().strftime('%Y-%m-%d'),
|
"date": datetime.datetime.now().strftime('%Y-%m-%d'),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,7 +122,7 @@ class PRReviewer:
|
|||||||
self.git_provider.pr,
|
self.git_provider.pr,
|
||||||
self.vars,
|
self.vars,
|
||||||
get_settings().pr_review_prompt.system,
|
get_settings().pr_review_prompt.system,
|
||||||
get_settings().pr_review_prompt.user
|
get_settings().pr_review_prompt.user,
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_incremental(self, args: List[str]):
|
def parse_incremental(self, args: List[str]):
|
||||||
@ -117,7 +140,10 @@ class PRReviewer:
|
|||||||
get_logger().info(f"PR has no files: {self.pr_url}, skipping review")
|
get_logger().info(f"PR has no files: {self.pr_url}, skipping review")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.incremental.is_incremental and not self._can_run_incremental_review():
|
if (
|
||||||
|
self.incremental.is_incremental
|
||||||
|
and not self._can_run_incremental_review()
|
||||||
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# if isinstance(self.args, list) and self.args and self.args[0] == 'auto_approve':
|
# if isinstance(self.args, list) and self.args and self.args[0] == 'auto_approve':
|
||||||
@ -126,27 +152,41 @@ class PRReviewer:
|
|||||||
# return None
|
# return None
|
||||||
|
|
||||||
get_logger().info(f'Reviewing PR: {self.pr_url} ...')
|
get_logger().info(f'Reviewing PR: {self.pr_url} ...')
|
||||||
relevant_configs = {'pr_reviewer': dict(get_settings().pr_reviewer),
|
relevant_configs = {
|
||||||
'config': dict(get_settings().config)}
|
'pr_reviewer': dict(get_settings().pr_reviewer),
|
||||||
|
'config': dict(get_settings().config),
|
||||||
|
}
|
||||||
get_logger().debug("Relevant configs", artifacts=relevant_configs)
|
get_logger().debug("Relevant configs", artifacts=relevant_configs)
|
||||||
|
|
||||||
# ticket extraction, if any exist
|
# ticket extraction, if any exist
|
||||||
await extract_and_cache_pr_tickets(self.git_provider, self.vars)
|
await extract_and_cache_pr_tickets(self.git_provider, self.vars)
|
||||||
|
|
||||||
if self.incremental.is_incremental and hasattr(self.git_provider, "unreviewed_files_set") and not self.git_provider.unreviewed_files_set:
|
if (
|
||||||
get_logger().info(f"Incremental review is enabled for {self.pr_url} but there are no new files")
|
self.incremental.is_incremental
|
||||||
|
and hasattr(self.git_provider, "unreviewed_files_set")
|
||||||
|
and not self.git_provider.unreviewed_files_set
|
||||||
|
):
|
||||||
|
get_logger().info(
|
||||||
|
f"Incremental review is enabled for {self.pr_url} but there are no new files"
|
||||||
|
)
|
||||||
previous_review_url = ""
|
previous_review_url = ""
|
||||||
if hasattr(self.git_provider, "previous_review"):
|
if hasattr(self.git_provider, "previous_review"):
|
||||||
previous_review_url = self.git_provider.previous_review.html_url
|
previous_review_url = self.git_provider.previous_review.html_url
|
||||||
if get_settings().config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
self.git_provider.publish_comment(f"Incremental Review Skipped\n"
|
self.git_provider.publish_comment(
|
||||||
f"No files were changed since the [previous PR Review]({previous_review_url})")
|
f"Incremental Review Skipped\n"
|
||||||
|
f"No files were changed since the [previous PR Review]({previous_review_url})"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False):
|
if get_settings().config.publish_output and not get_settings().config.get(
|
||||||
|
'is_auto_command', False
|
||||||
|
):
|
||||||
self.git_provider.publish_comment("准备评审中...", is_temporary=True)
|
self.git_provider.publish_comment("准备评审中...", is_temporary=True)
|
||||||
|
|
||||||
await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
|
await retry_with_fallback_models(
|
||||||
|
self._prepare_prediction, model_type=ModelType.REGULAR
|
||||||
|
)
|
||||||
if not self.prediction:
|
if not self.prediction:
|
||||||
self.git_provider.remove_initial_comment()
|
self.git_provider.remove_initial_comment()
|
||||||
return None
|
return None
|
||||||
@ -156,12 +196,19 @@ class PRReviewer:
|
|||||||
|
|
||||||
if get_settings().config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
# publish the review
|
# publish the review
|
||||||
if get_settings().pr_reviewer.persistent_comment and not self.incremental.is_incremental:
|
if (
|
||||||
final_update_message = get_settings().pr_reviewer.final_update_message
|
get_settings().pr_reviewer.persistent_comment
|
||||||
self.git_provider.publish_persistent_comment(pr_review,
|
and not self.incremental.is_incremental
|
||||||
initial_header=f"{PRReviewHeader.REGULAR.value} 🔍",
|
):
|
||||||
update_header=True,
|
final_update_message = (
|
||||||
final_update_message=final_update_message, )
|
get_settings().pr_reviewer.final_update_message
|
||||||
|
)
|
||||||
|
self.git_provider.publish_persistent_comment(
|
||||||
|
pr_review,
|
||||||
|
initial_header=f"{PRReviewHeader.REGULAR.value} 🔍",
|
||||||
|
update_header=True,
|
||||||
|
final_update_message=final_update_message,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.git_provider.publish_comment(pr_review)
|
self.git_provider.publish_comment(pr_review)
|
||||||
|
|
||||||
@ -174,11 +221,13 @@ class PRReviewer:
|
|||||||
get_logger().error(f"Failed to review PR: {e}")
|
get_logger().error(f"Failed to review PR: {e}")
|
||||||
|
|
||||||
async def _prepare_prediction(self, model: str) -> None:
|
async def _prepare_prediction(self, model: str) -> None:
|
||||||
self.patches_diff = get_pr_diff(self.git_provider,
|
self.patches_diff = get_pr_diff(
|
||||||
self.token_handler,
|
self.git_provider,
|
||||||
model,
|
self.token_handler,
|
||||||
add_line_numbers_to_hunks=True,
|
model,
|
||||||
disable_extra_lines=False,)
|
add_line_numbers_to_hunks=True,
|
||||||
|
disable_extra_lines=False,
|
||||||
|
)
|
||||||
|
|
||||||
if self.patches_diff:
|
if self.patches_diff:
|
||||||
get_logger().debug(f"PR diff", diff=self.patches_diff)
|
get_logger().debug(f"PR diff", diff=self.patches_diff)
|
||||||
@ -201,14 +250,18 @@ class PRReviewer:
|
|||||||
variables["diff"] = self.patches_diff # update diff
|
variables["diff"] = self.patches_diff # update diff
|
||||||
|
|
||||||
environment = Environment(undefined=StrictUndefined)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
system_prompt = environment.from_string(get_settings().pr_review_prompt.system).render(variables)
|
system_prompt = environment.from_string(
|
||||||
user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables)
|
get_settings().pr_review_prompt.system
|
||||||
|
).render(variables)
|
||||||
|
user_prompt = environment.from_string(
|
||||||
|
get_settings().pr_review_prompt.user
|
||||||
|
).render(variables)
|
||||||
|
|
||||||
response, finish_reason = await self.ai_handler.chat_completion(
|
response, finish_reason = await self.ai_handler.chat_completion(
|
||||||
model=model,
|
model=model,
|
||||||
temperature=get_settings().config.temperature,
|
temperature=get_settings().config.temperature,
|
||||||
system=system_prompt,
|
system=system_prompt,
|
||||||
user=user_prompt
|
user=user_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
@ -220,10 +273,20 @@ class PRReviewer:
|
|||||||
"""
|
"""
|
||||||
first_key = 'review'
|
first_key = 'review'
|
||||||
last_key = 'security_concerns'
|
last_key = 'security_concerns'
|
||||||
data = load_yaml(self.prediction.strip(),
|
data = load_yaml(
|
||||||
keys_fix_yaml=["ticket_compliance_check", "estimated_effort_to_review_[1-5]:", "security_concerns:", "key_issues_to_review:",
|
self.prediction.strip(),
|
||||||
"relevant_file:", "relevant_line:", "suggestion:"],
|
keys_fix_yaml=[
|
||||||
first_key=first_key, last_key=last_key)
|
"ticket_compliance_check",
|
||||||
|
"estimated_effort_to_review_[1-5]:",
|
||||||
|
"security_concerns:",
|
||||||
|
"key_issues_to_review:",
|
||||||
|
"relevant_file:",
|
||||||
|
"relevant_line:",
|
||||||
|
"suggestion:",
|
||||||
|
],
|
||||||
|
first_key=first_key,
|
||||||
|
last_key=last_key,
|
||||||
|
)
|
||||||
github_action_output(data, 'review')
|
github_action_output(data, 'review')
|
||||||
|
|
||||||
# move the 'key_issues_to_review' key in data['review'] to the end of the dictionary
|
# move the 'key_issues_to_review' key in data['review'] to the end of the dictionary
|
||||||
@ -234,24 +297,38 @@ class PRReviewer:
|
|||||||
incremental_review_markdown_text = None
|
incremental_review_markdown_text = None
|
||||||
# Add incremental review section
|
# Add incremental review section
|
||||||
if self.incremental.is_incremental:
|
if self.incremental.is_incremental:
|
||||||
last_commit_url = f"{self.git_provider.get_pr_url()}/commits/" \
|
last_commit_url = (
|
||||||
f"{self.git_provider.incremental.first_new_commit_sha}"
|
f"{self.git_provider.get_pr_url()}/commits/"
|
||||||
|
f"{self.git_provider.incremental.first_new_commit_sha}"
|
||||||
|
)
|
||||||
incremental_review_markdown_text = f"Starting from commit {last_commit_url}"
|
incremental_review_markdown_text = f"Starting from commit {last_commit_url}"
|
||||||
|
|
||||||
markdown_text = convert_to_markdown_v2(data, self.git_provider.is_supported("gfm_markdown"),
|
markdown_text = convert_to_markdown_v2(
|
||||||
incremental_review_markdown_text,
|
data,
|
||||||
git_provider=self.git_provider,
|
self.git_provider.is_supported("gfm_markdown"),
|
||||||
files=self.git_provider.get_diff_files())
|
incremental_review_markdown_text,
|
||||||
|
git_provider=self.git_provider,
|
||||||
|
files=self.git_provider.get_diff_files(),
|
||||||
|
)
|
||||||
|
|
||||||
# Add help text if gfm_markdown is supported
|
# Add help text if gfm_markdown is supported
|
||||||
if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_reviewer.enable_help_text:
|
if (
|
||||||
|
self.git_provider.is_supported("gfm_markdown")
|
||||||
|
and get_settings().pr_reviewer.enable_help_text
|
||||||
|
):
|
||||||
markdown_text += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
|
markdown_text += "<hr>\n\n<details> <summary><strong>💡 Tool usage guide:</strong></summary><hr> \n\n"
|
||||||
markdown_text += HelpMessage.get_review_usage_guide()
|
markdown_text += HelpMessage.get_review_usage_guide()
|
||||||
markdown_text += "\n</details>\n"
|
markdown_text += "\n</details>\n"
|
||||||
|
|
||||||
# Output the relevant configurations if enabled
|
# Output the relevant configurations if enabled
|
||||||
if get_settings().get('config', {}).get('output_relevant_configurations', False):
|
if (
|
||||||
markdown_text += show_relevant_configurations(relevant_section='pr_reviewer')
|
get_settings()
|
||||||
|
.get('config', {})
|
||||||
|
.get('output_relevant_configurations', False)
|
||||||
|
):
|
||||||
|
markdown_text += show_relevant_configurations(
|
||||||
|
relevant_section='pr_reviewer'
|
||||||
|
)
|
||||||
|
|
||||||
# Add custom labels from the review prediction (effort, security)
|
# Add custom labels from the review prediction (effort, security)
|
||||||
self.set_review_labels(data)
|
self.set_review_labels(data)
|
||||||
@ -306,34 +383,50 @@ class PRReviewer:
|
|||||||
if comment:
|
if comment:
|
||||||
self.git_provider.remove_comment(comment)
|
self.git_provider.remove_comment(comment)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().exception(f"Failed to remove previous review comment, error: {e}")
|
get_logger().exception(
|
||||||
|
f"Failed to remove previous review comment, error: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
def _can_run_incremental_review(self) -> bool:
|
def _can_run_incremental_review(self) -> bool:
|
||||||
"""Checks if we can run incremental review according the various configurations and previous review"""
|
"""Checks if we can run incremental review according the various configurations and previous review"""
|
||||||
# checking if running in auto mode but there are no new commits
|
# checking if running in auto mode but there are no new commits
|
||||||
if self.is_auto and not self.incremental.first_new_commit_sha:
|
if self.is_auto and not self.incremental.first_new_commit_sha:
|
||||||
get_logger().info(f"Incremental review is enabled for {self.pr_url} but there are no new commits")
|
get_logger().info(
|
||||||
|
f"Incremental review is enabled for {self.pr_url} but there are no new commits"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not hasattr(self.git_provider, "get_incremental_commits"):
|
if not hasattr(self.git_provider, "get_incremental_commits"):
|
||||||
get_logger().info(f"Incremental review is not supported for {get_settings().config.git_provider}")
|
get_logger().info(
|
||||||
|
f"Incremental review is not supported for {get_settings().config.git_provider}"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
# checking if there are enough commits to start the review
|
# checking if there are enough commits to start the review
|
||||||
num_new_commits = len(self.incremental.commits_range)
|
num_new_commits = len(self.incremental.commits_range)
|
||||||
num_commits_threshold = get_settings().pr_reviewer.minimal_commits_for_incremental_review
|
num_commits_threshold = (
|
||||||
|
get_settings().pr_reviewer.minimal_commits_for_incremental_review
|
||||||
|
)
|
||||||
not_enough_commits = num_new_commits < num_commits_threshold
|
not_enough_commits = num_new_commits < num_commits_threshold
|
||||||
# checking if the commits are not too recent to start the review
|
# checking if the commits are not too recent to start the review
|
||||||
recent_commits_threshold = datetime.datetime.now() - datetime.timedelta(
|
recent_commits_threshold = datetime.datetime.now() - datetime.timedelta(
|
||||||
minutes=get_settings().pr_reviewer.minimal_minutes_for_incremental_review
|
minutes=get_settings().pr_reviewer.minimal_minutes_for_incremental_review
|
||||||
)
|
)
|
||||||
last_seen_commit_date = (
|
last_seen_commit_date = (
|
||||||
self.incremental.last_seen_commit.commit.author.date if self.incremental.last_seen_commit else None
|
self.incremental.last_seen_commit.commit.author.date
|
||||||
|
if self.incremental.last_seen_commit
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
all_commits_too_recent = (
|
all_commits_too_recent = (
|
||||||
last_seen_commit_date > recent_commits_threshold if self.incremental.last_seen_commit else False
|
last_seen_commit_date > recent_commits_threshold
|
||||||
|
if self.incremental.last_seen_commit
|
||||||
|
else False
|
||||||
)
|
)
|
||||||
# check all the thresholds or just one to start the review
|
# check all the thresholds or just one to start the review
|
||||||
condition = any if get_settings().pr_reviewer.require_all_thresholds_for_incremental_review else all
|
condition = (
|
||||||
|
any
|
||||||
|
if get_settings().pr_reviewer.require_all_thresholds_for_incremental_review
|
||||||
|
else all
|
||||||
|
)
|
||||||
if condition((not_enough_commits, all_commits_too_recent)):
|
if condition((not_enough_commits, all_commits_too_recent)):
|
||||||
get_logger().info(
|
get_logger().info(
|
||||||
f"Incremental review is enabled for {self.pr_url} but didn't pass the threshold check to run:"
|
f"Incremental review is enabled for {self.pr_url} but didn't pass the threshold check to run:"
|
||||||
@ -348,31 +441,55 @@ class PRReviewer:
|
|||||||
return
|
return
|
||||||
|
|
||||||
if not get_settings().pr_reviewer.require_estimate_effort_to_review:
|
if not get_settings().pr_reviewer.require_estimate_effort_to_review:
|
||||||
get_settings().pr_reviewer.enable_review_labels_effort = False # we did not generate this output
|
get_settings().pr_reviewer.enable_review_labels_effort = (
|
||||||
|
False # we did not generate this output
|
||||||
|
)
|
||||||
if not get_settings().pr_reviewer.require_security_review:
|
if not get_settings().pr_reviewer.require_security_review:
|
||||||
get_settings().pr_reviewer.enable_review_labels_security = False # we did not generate this output
|
get_settings().pr_reviewer.enable_review_labels_security = (
|
||||||
|
False # we did not generate this output
|
||||||
|
)
|
||||||
|
|
||||||
if (get_settings().pr_reviewer.enable_review_labels_security or
|
if (
|
||||||
get_settings().pr_reviewer.enable_review_labels_effort):
|
get_settings().pr_reviewer.enable_review_labels_security
|
||||||
|
or get_settings().pr_reviewer.enable_review_labels_effort
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
review_labels = []
|
review_labels = []
|
||||||
if get_settings().pr_reviewer.enable_review_labels_effort:
|
if get_settings().pr_reviewer.enable_review_labels_effort:
|
||||||
estimated_effort = data['review']['estimated_effort_to_review_[1-5]']
|
estimated_effort = data['review'][
|
||||||
|
'estimated_effort_to_review_[1-5]'
|
||||||
|
]
|
||||||
estimated_effort_number = 0
|
estimated_effort_number = 0
|
||||||
if isinstance(estimated_effort, str):
|
if isinstance(estimated_effort, str):
|
||||||
try:
|
try:
|
||||||
estimated_effort_number = int(estimated_effort.split(',')[0])
|
estimated_effort_number = int(
|
||||||
|
estimated_effort.split(',')[0]
|
||||||
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
get_logger().warning(f"Invalid estimated_effort value: {estimated_effort}")
|
get_logger().warning(
|
||||||
|
f"Invalid estimated_effort value: {estimated_effort}"
|
||||||
|
)
|
||||||
elif isinstance(estimated_effort, int):
|
elif isinstance(estimated_effort, int):
|
||||||
estimated_effort_number = estimated_effort
|
estimated_effort_number = estimated_effort
|
||||||
else:
|
else:
|
||||||
get_logger().warning(f"Unexpected type for estimated_effort: {type(estimated_effort)}")
|
get_logger().warning(
|
||||||
|
f"Unexpected type for estimated_effort: {type(estimated_effort)}"
|
||||||
|
)
|
||||||
if 1 <= estimated_effort_number <= 5: # 1, because ...
|
if 1 <= estimated_effort_number <= 5: # 1, because ...
|
||||||
review_labels.append(f'Review effort {estimated_effort_number}/5')
|
review_labels.append(
|
||||||
if get_settings().pr_reviewer.enable_review_labels_security and get_settings().pr_reviewer.require_security_review:
|
f'Review effort {estimated_effort_number}/5'
|
||||||
security_concerns = data['review']['security_concerns'] # yes, because ...
|
)
|
||||||
security_concerns_bool = 'yes' in security_concerns.lower() or 'true' in security_concerns.lower()
|
if (
|
||||||
|
get_settings().pr_reviewer.enable_review_labels_security
|
||||||
|
and get_settings().pr_reviewer.require_security_review
|
||||||
|
):
|
||||||
|
security_concerns = data['review'][
|
||||||
|
'security_concerns'
|
||||||
|
] # yes, because ...
|
||||||
|
security_concerns_bool = (
|
||||||
|
'yes' in security_concerns.lower()
|
||||||
|
or 'true' in security_concerns.lower()
|
||||||
|
)
|
||||||
if security_concerns_bool:
|
if security_concerns_bool:
|
||||||
review_labels.append('Possible security concern')
|
review_labels.append('Possible security concern')
|
||||||
|
|
||||||
@ -381,17 +498,26 @@ class PRReviewer:
|
|||||||
current_labels = []
|
current_labels = []
|
||||||
get_logger().debug(f"Current labels:\n{current_labels}")
|
get_logger().debug(f"Current labels:\n{current_labels}")
|
||||||
if current_labels:
|
if current_labels:
|
||||||
current_labels_filtered = [label for label in current_labels if
|
current_labels_filtered = [
|
||||||
not label.lower().startswith('review effort') and not label.lower().startswith(
|
label
|
||||||
'possible security concern')]
|
for label in current_labels
|
||||||
|
if not label.lower().startswith('review effort')
|
||||||
|
and not label.lower().startswith('possible security concern')
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
current_labels_filtered = []
|
current_labels_filtered = []
|
||||||
new_labels = review_labels + current_labels_filtered
|
new_labels = review_labels + current_labels_filtered
|
||||||
if (current_labels or review_labels) and sorted(new_labels) != sorted(current_labels):
|
if (current_labels or review_labels) and sorted(new_labels) != sorted(
|
||||||
get_logger().info(f"Setting review labels:\n{review_labels + current_labels_filtered}")
|
current_labels
|
||||||
|
):
|
||||||
|
get_logger().info(
|
||||||
|
f"Setting review labels:\n{review_labels + current_labels_filtered}"
|
||||||
|
)
|
||||||
self.git_provider.publish_labels(new_labels)
|
self.git_provider.publish_labels(new_labels)
|
||||||
else:
|
else:
|
||||||
get_logger().info(f"Review labels are already set:\n{review_labels + current_labels_filtered}")
|
get_logger().info(
|
||||||
|
f"Review labels are already set:\n{review_labels + current_labels_filtered}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().error(f"Failed to set review labels, error: {e}")
|
get_logger().error(f"Failed to set review labels, error: {e}")
|
||||||
|
|
||||||
@ -406,5 +532,7 @@ class PRReviewer:
|
|||||||
self.git_provider.publish_comment("自动批准 PR")
|
self.git_provider.publish_comment("自动批准 PR")
|
||||||
else:
|
else:
|
||||||
get_logger().info("Auto-approval option is disabled")
|
get_logger().info("Auto-approval option is disabled")
|
||||||
self.git_provider.publish_comment("PR-Agent 的自动批准选项已禁用. "
|
self.git_provider.publish_comment(
|
||||||
"你可以通过此设置打开 [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)")
|
"PR-Agent 的自动批准选项已禁用. "
|
||||||
|
"你可以通过此设置打开 [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)"
|
||||||
|
)
|
||||||
|
|||||||
@ -24,12 +24,16 @@ class PRSimilarIssue:
|
|||||||
self.max_issues_to_scan = get_settings().pr_similar_issue.max_issues_to_scan
|
self.max_issues_to_scan = get_settings().pr_similar_issue.max_issues_to_scan
|
||||||
self.issue_url = issue_url
|
self.issue_url = issue_url
|
||||||
self.git_provider = get_git_provider()()
|
self.git_provider = get_git_provider()()
|
||||||
repo_name, issue_number = self.git_provider._parse_issue_url(issue_url.split('=')[-1])
|
repo_name, issue_number = self.git_provider._parse_issue_url(
|
||||||
|
issue_url.split('=')[-1]
|
||||||
|
)
|
||||||
self.git_provider.repo = repo_name
|
self.git_provider.repo = repo_name
|
||||||
self.git_provider.repo_obj = self.git_provider.github_client.get_repo(repo_name)
|
self.git_provider.repo_obj = self.git_provider.github_client.get_repo(repo_name)
|
||||||
self.token_handler = TokenHandler()
|
self.token_handler = TokenHandler()
|
||||||
repo_obj = self.git_provider.repo_obj
|
repo_obj = self.git_provider.repo_obj
|
||||||
repo_name_for_index = self.repo_name_for_index = repo_obj.full_name.lower().replace('/', '-').replace('_/', '-')
|
repo_name_for_index = self.repo_name_for_index = (
|
||||||
|
repo_obj.full_name.lower().replace('/', '-').replace('_/', '-')
|
||||||
|
)
|
||||||
index_name = self.index_name = "codium-ai-pr-agent-issues"
|
index_name = self.index_name = "codium-ai-pr-agent-issues"
|
||||||
|
|
||||||
if get_settings().pr_similar_issue.vectordb == "pinecone":
|
if get_settings().pr_similar_issue.vectordb == "pinecone":
|
||||||
@ -38,17 +42,30 @@ class PRSimilarIssue:
|
|||||||
import pinecone
|
import pinecone
|
||||||
from pinecone_datasets import Dataset, DatasetMetadata
|
from pinecone_datasets import Dataset, DatasetMetadata
|
||||||
except:
|
except:
|
||||||
raise Exception("Please install 'pinecone' and 'pinecone_datasets' to use pinecone as vectordb")
|
raise Exception(
|
||||||
|
"Please install 'pinecone' and 'pinecone_datasets' to use pinecone as vectordb"
|
||||||
|
)
|
||||||
# assuming pinecone api key and environment are set in secrets file
|
# assuming pinecone api key and environment are set in secrets file
|
||||||
try:
|
try:
|
||||||
api_key = get_settings().pinecone.api_key
|
api_key = get_settings().pinecone.api_key
|
||||||
environment = get_settings().pinecone.environment
|
environment = get_settings().pinecone.environment
|
||||||
except Exception:
|
except Exception:
|
||||||
if not self.cli_mode:
|
if not self.cli_mode:
|
||||||
repo_name, original_issue_number = self.git_provider._parse_issue_url(self.issue_url.split('=')[-1])
|
(
|
||||||
issue_main = self.git_provider.repo_obj.get_issue(original_issue_number)
|
repo_name,
|
||||||
issue_main.create_comment("Please set pinecone api key and environment in secrets file")
|
original_issue_number,
|
||||||
raise Exception("Please set pinecone api key and environment in secrets file")
|
) = self.git_provider._parse_issue_url(
|
||||||
|
self.issue_url.split('=')[-1]
|
||||||
|
)
|
||||||
|
issue_main = self.git_provider.repo_obj.get_issue(
|
||||||
|
original_issue_number
|
||||||
|
)
|
||||||
|
issue_main.create_comment(
|
||||||
|
"Please set pinecone api key and environment in secrets file"
|
||||||
|
)
|
||||||
|
raise Exception(
|
||||||
|
"Please set pinecone api key and environment in secrets file"
|
||||||
|
)
|
||||||
|
|
||||||
# check if index exists, and if repo is already indexed
|
# check if index exists, and if repo is already indexed
|
||||||
run_from_scratch = False
|
run_from_scratch = False
|
||||||
@ -69,7 +86,9 @@ class PRSimilarIssue:
|
|||||||
upsert = True
|
upsert = True
|
||||||
else:
|
else:
|
||||||
pinecone_index = pinecone.Index(index_name=index_name)
|
pinecone_index = pinecone.Index(index_name=index_name)
|
||||||
res = pinecone_index.fetch([f"example_issue_{repo_name_for_index}"]).to_dict()
|
res = pinecone_index.fetch(
|
||||||
|
[f"example_issue_{repo_name_for_index}"]
|
||||||
|
).to_dict()
|
||||||
if res["vectors"]:
|
if res["vectors"]:
|
||||||
upsert = False
|
upsert = False
|
||||||
|
|
||||||
@ -79,7 +98,9 @@ class PRSimilarIssue:
|
|||||||
get_logger().info('Getting issues...')
|
get_logger().info('Getting issues...')
|
||||||
issues = list(repo_obj.get_issues(state='all'))
|
issues = list(repo_obj.get_issues(state='all'))
|
||||||
get_logger().info('Done')
|
get_logger().info('Done')
|
||||||
self._update_index_with_issues(issues, repo_name_for_index, upsert=upsert)
|
self._update_index_with_issues(
|
||||||
|
issues, repo_name_for_index, upsert=upsert
|
||||||
|
)
|
||||||
else: # update index if needed
|
else: # update index if needed
|
||||||
pinecone_index = pinecone.Index(index_name=index_name)
|
pinecone_index = pinecone.Index(index_name=index_name)
|
||||||
issues_to_update = []
|
issues_to_update = []
|
||||||
@ -105,7 +126,9 @@ class PRSimilarIssue:
|
|||||||
|
|
||||||
if issues_to_update:
|
if issues_to_update:
|
||||||
get_logger().info(f'Updating index with {counter} new issues...')
|
get_logger().info(f'Updating index with {counter} new issues...')
|
||||||
self._update_index_with_issues(issues_to_update, repo_name_for_index, upsert=True)
|
self._update_index_with_issues(
|
||||||
|
issues_to_update, repo_name_for_index, upsert=True
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
get_logger().info('No new issues to update')
|
get_logger().info('No new issues to update')
|
||||||
|
|
||||||
@ -133,7 +156,12 @@ class PRSimilarIssue:
|
|||||||
ingest = True
|
ingest = True
|
||||||
else:
|
else:
|
||||||
self.table = self.db[index_name]
|
self.table = self.db[index_name]
|
||||||
res = self.table.search().limit(len(self.table)).where(f"id='example_issue_{repo_name_for_index}'").to_list()
|
res = (
|
||||||
|
self.table.search()
|
||||||
|
.limit(len(self.table))
|
||||||
|
.where(f"id='example_issue_{repo_name_for_index}'")
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
get_logger().info("result: ", res)
|
get_logger().info("result: ", res)
|
||||||
if res[0].get("vector"):
|
if res[0].get("vector"):
|
||||||
ingest = False
|
ingest = False
|
||||||
@ -145,7 +173,9 @@ class PRSimilarIssue:
|
|||||||
issues = list(repo_obj.get_issues(state='all'))
|
issues = list(repo_obj.get_issues(state='all'))
|
||||||
get_logger().info('Done')
|
get_logger().info('Done')
|
||||||
|
|
||||||
self._update_table_with_issues(issues, repo_name_for_index, ingest=ingest)
|
self._update_table_with_issues(
|
||||||
|
issues, repo_name_for_index, ingest=ingest
|
||||||
|
)
|
||||||
else: # update table if needed
|
else: # update table if needed
|
||||||
issues_to_update = []
|
issues_to_update = []
|
||||||
issues_paginated_list = repo_obj.get_issues(state='all')
|
issues_paginated_list = repo_obj.get_issues(state='all')
|
||||||
@ -156,7 +186,12 @@ class PRSimilarIssue:
|
|||||||
issue_str, comments, number = self._process_issue(issue)
|
issue_str, comments, number = self._process_issue(issue)
|
||||||
issue_key = f"issue_{number}"
|
issue_key = f"issue_{number}"
|
||||||
issue_id = issue_key + "." + "issue"
|
issue_id = issue_key + "." + "issue"
|
||||||
res = self.table.search().limit(len(self.table)).where(f"id='{issue_id}'").to_list()
|
res = (
|
||||||
|
self.table.search()
|
||||||
|
.limit(len(self.table))
|
||||||
|
.where(f"id='{issue_id}'")
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
is_new_issue = True
|
is_new_issue = True
|
||||||
for r in res:
|
for r in res:
|
||||||
if r['metadata']['repo'] == repo_name_for_index:
|
if r['metadata']['repo'] == repo_name_for_index:
|
||||||
@ -170,14 +205,17 @@ class PRSimilarIssue:
|
|||||||
|
|
||||||
if issues_to_update:
|
if issues_to_update:
|
||||||
get_logger().info(f'Updating index with {counter} new issues...')
|
get_logger().info(f'Updating index with {counter} new issues...')
|
||||||
self._update_table_with_issues(issues_to_update, repo_name_for_index, ingest=True)
|
self._update_table_with_issues(
|
||||||
|
issues_to_update, repo_name_for_index, ingest=True
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
get_logger().info('No new issues to update')
|
get_logger().info('No new issues to update')
|
||||||
|
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
get_logger().info('Getting issue...')
|
get_logger().info('Getting issue...')
|
||||||
repo_name, original_issue_number = self.git_provider._parse_issue_url(self.issue_url.split('=')[-1])
|
repo_name, original_issue_number = self.git_provider._parse_issue_url(
|
||||||
|
self.issue_url.split('=')[-1]
|
||||||
|
)
|
||||||
issue_main = self.git_provider.repo_obj.get_issue(original_issue_number)
|
issue_main = self.git_provider.repo_obj.get_issue(original_issue_number)
|
||||||
issue_str, comments, number = self._process_issue(issue_main)
|
issue_str, comments, number = self._process_issue(issue_main)
|
||||||
openai.api_key = get_settings().openai.key
|
openai.api_key = get_settings().openai.key
|
||||||
@ -193,10 +231,12 @@ class PRSimilarIssue:
|
|||||||
|
|
||||||
if get_settings().pr_similar_issue.vectordb == "pinecone":
|
if get_settings().pr_similar_issue.vectordb == "pinecone":
|
||||||
pinecone_index = pinecone.Index(index_name=self.index_name)
|
pinecone_index = pinecone.Index(index_name=self.index_name)
|
||||||
res = pinecone_index.query(embeds[0],
|
res = pinecone_index.query(
|
||||||
top_k=5,
|
embeds[0],
|
||||||
filter={"repo": self.repo_name_for_index},
|
top_k=5,
|
||||||
include_metadata=True).to_dict()
|
filter={"repo": self.repo_name_for_index},
|
||||||
|
include_metadata=True,
|
||||||
|
).to_dict()
|
||||||
|
|
||||||
for r in res['matches']:
|
for r in res['matches']:
|
||||||
# skip example issue
|
# skip example issue
|
||||||
@ -214,14 +254,20 @@ class PRSimilarIssue:
|
|||||||
if issue_number not in relevant_issues_number_list:
|
if issue_number not in relevant_issues_number_list:
|
||||||
relevant_issues_number_list.append(issue_number)
|
relevant_issues_number_list.append(issue_number)
|
||||||
if 'comment' in r["id"]:
|
if 'comment' in r["id"]:
|
||||||
relevant_comment_number_list.append(int(r["id"].split('.')[1].split('_')[-1]))
|
relevant_comment_number_list.append(
|
||||||
|
int(r["id"].split('.')[1].split('_')[-1])
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
relevant_comment_number_list.append(-1)
|
relevant_comment_number_list.append(-1)
|
||||||
score_list.append(str("{:.2f}".format(r['score'])))
|
score_list.append(str("{:.2f}".format(r['score'])))
|
||||||
get_logger().info('Done')
|
get_logger().info('Done')
|
||||||
|
|
||||||
elif get_settings().pr_similar_issue.vectordb == "lancedb":
|
elif get_settings().pr_similar_issue.vectordb == "lancedb":
|
||||||
res = self.table.search(embeds[0]).where(f"metadata.repo='{self.repo_name_for_index}'", prefilter=True).to_list()
|
res = (
|
||||||
|
self.table.search(embeds[0])
|
||||||
|
.where(f"metadata.repo='{self.repo_name_for_index}'", prefilter=True)
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
|
||||||
for r in res:
|
for r in res:
|
||||||
# skip example issue
|
# skip example issue
|
||||||
@ -240,10 +286,12 @@ class PRSimilarIssue:
|
|||||||
relevant_issues_number_list.append(issue_number)
|
relevant_issues_number_list.append(issue_number)
|
||||||
|
|
||||||
if 'comment' in r["id"]:
|
if 'comment' in r["id"]:
|
||||||
relevant_comment_number_list.append(int(r["id"].split('.')[1].split('_')[-1]))
|
relevant_comment_number_list.append(
|
||||||
|
int(r["id"].split('.')[1].split('_')[-1])
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
relevant_comment_number_list.append(-1)
|
relevant_comment_number_list.append(-1)
|
||||||
score_list.append(str("{:.2f}".format(1-r['_distance'])))
|
score_list.append(str("{:.2f}".format(1 - r['_distance'])))
|
||||||
get_logger().info('Done')
|
get_logger().info('Done')
|
||||||
|
|
||||||
get_logger().info('Publishing response...')
|
get_logger().info('Publishing response...')
|
||||||
@ -254,8 +302,12 @@ class PRSimilarIssue:
|
|||||||
title = issue.title
|
title = issue.title
|
||||||
url = issue.html_url
|
url = issue.html_url
|
||||||
if relevant_comment_number_list[i] != -1:
|
if relevant_comment_number_list[i] != -1:
|
||||||
url = list(issue.get_comments())[relevant_comment_number_list[i]].html_url
|
url = list(issue.get_comments())[
|
||||||
similar_issues_str += f"{i + 1}. **[{title}]({url})** (score={score_list[i]})\n\n"
|
relevant_comment_number_list[i]
|
||||||
|
].html_url
|
||||||
|
similar_issues_str += (
|
||||||
|
f"{i + 1}. **[{title}]({url})** (score={score_list[i]})\n\n"
|
||||||
|
)
|
||||||
if get_settings().config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
response = issue_main.create_comment(similar_issues_str)
|
response = issue_main.create_comment(similar_issues_str)
|
||||||
get_logger().info(similar_issues_str)
|
get_logger().info(similar_issues_str)
|
||||||
@ -278,7 +330,7 @@ class PRSimilarIssue:
|
|||||||
example_issue_record = Record(
|
example_issue_record = Record(
|
||||||
id=f"example_issue_{repo_name_for_index}",
|
id=f"example_issue_{repo_name_for_index}",
|
||||||
text="example_issue",
|
text="example_issue",
|
||||||
metadata=Metadata(repo=repo_name_for_index)
|
metadata=Metadata(repo=repo_name_for_index),
|
||||||
)
|
)
|
||||||
corpus.append(example_issue_record)
|
corpus.append(example_issue_record)
|
||||||
|
|
||||||
@ -298,15 +350,20 @@ class PRSimilarIssue:
|
|||||||
issue_key = f"issue_{number}"
|
issue_key = f"issue_{number}"
|
||||||
username = issue.user.login
|
username = issue.user.login
|
||||||
created_at = str(issue.created_at)
|
created_at = str(issue.created_at)
|
||||||
if len(issue_str) < 8000 or \
|
if len(issue_str) < 8000 or self.token_handler.count_tokens(
|
||||||
self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL): # fast reject first
|
issue_str
|
||||||
|
) < get_max_tokens(
|
||||||
|
MODEL
|
||||||
|
): # fast reject first
|
||||||
issue_record = Record(
|
issue_record = Record(
|
||||||
id=issue_key + "." + "issue",
|
id=issue_key + "." + "issue",
|
||||||
text=issue_str,
|
text=issue_str,
|
||||||
metadata=Metadata(repo=repo_name_for_index,
|
metadata=Metadata(
|
||||||
username=username,
|
repo=repo_name_for_index,
|
||||||
created_at=created_at,
|
username=username,
|
||||||
level=IssueLevel.ISSUE)
|
created_at=created_at,
|
||||||
|
level=IssueLevel.ISSUE,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
corpus.append(issue_record)
|
corpus.append(issue_record)
|
||||||
if comments:
|
if comments:
|
||||||
@ -316,15 +373,20 @@ class PRSimilarIssue:
|
|||||||
if num_words_comment < 10 or not isinstance(comment_body, str):
|
if num_words_comment < 10 or not isinstance(comment_body, str):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if len(comment_body) < 8000 or \
|
if (
|
||||||
self.token_handler.count_tokens(comment_body) < MAX_TOKENS[MODEL]:
|
len(comment_body) < 8000
|
||||||
|
or self.token_handler.count_tokens(comment_body)
|
||||||
|
< MAX_TOKENS[MODEL]
|
||||||
|
):
|
||||||
comment_record = Record(
|
comment_record = Record(
|
||||||
id=issue_key + ".comment_" + str(j + 1),
|
id=issue_key + ".comment_" + str(j + 1),
|
||||||
text=comment_body,
|
text=comment_body,
|
||||||
metadata=Metadata(repo=repo_name_for_index,
|
metadata=Metadata(
|
||||||
username=username, # use issue username for all comments
|
repo=repo_name_for_index,
|
||||||
created_at=created_at,
|
username=username, # use issue username for all comments
|
||||||
level=IssueLevel.COMMENT)
|
created_at=created_at,
|
||||||
|
level=IssueLevel.COMMENT,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
corpus.append(comment_record)
|
corpus.append(comment_record)
|
||||||
df = pd.DataFrame(corpus.dict()["documents"])
|
df = pd.DataFrame(corpus.dict()["documents"])
|
||||||
@ -355,7 +417,9 @@ class PRSimilarIssue:
|
|||||||
environment = get_settings().pinecone.environment
|
environment = get_settings().pinecone.environment
|
||||||
if not upsert:
|
if not upsert:
|
||||||
get_logger().info('Creating index from scratch...')
|
get_logger().info('Creating index from scratch...')
|
||||||
ds.to_pinecone_index(self.index_name, api_key=api_key, environment=environment)
|
ds.to_pinecone_index(
|
||||||
|
self.index_name, api_key=api_key, environment=environment
|
||||||
|
)
|
||||||
time.sleep(15) # wait for pinecone to finalize indexing before querying
|
time.sleep(15) # wait for pinecone to finalize indexing before querying
|
||||||
else:
|
else:
|
||||||
get_logger().info('Upserting index...')
|
get_logger().info('Upserting index...')
|
||||||
@ -374,7 +438,7 @@ class PRSimilarIssue:
|
|||||||
example_issue_record = Record(
|
example_issue_record = Record(
|
||||||
id=f"example_issue_{repo_name_for_index}",
|
id=f"example_issue_{repo_name_for_index}",
|
||||||
text="example_issue",
|
text="example_issue",
|
||||||
metadata=Metadata(repo=repo_name_for_index)
|
metadata=Metadata(repo=repo_name_for_index),
|
||||||
)
|
)
|
||||||
corpus.append(example_issue_record)
|
corpus.append(example_issue_record)
|
||||||
|
|
||||||
@ -394,15 +458,20 @@ class PRSimilarIssue:
|
|||||||
issue_key = f"issue_{number}"
|
issue_key = f"issue_{number}"
|
||||||
username = issue.user.login
|
username = issue.user.login
|
||||||
created_at = str(issue.created_at)
|
created_at = str(issue.created_at)
|
||||||
if len(issue_str) < 8000 or \
|
if len(issue_str) < 8000 or self.token_handler.count_tokens(
|
||||||
self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL): # fast reject first
|
issue_str
|
||||||
|
) < get_max_tokens(
|
||||||
|
MODEL
|
||||||
|
): # fast reject first
|
||||||
issue_record = Record(
|
issue_record = Record(
|
||||||
id=issue_key + "." + "issue",
|
id=issue_key + "." + "issue",
|
||||||
text=issue_str,
|
text=issue_str,
|
||||||
metadata=Metadata(repo=repo_name_for_index,
|
metadata=Metadata(
|
||||||
username=username,
|
repo=repo_name_for_index,
|
||||||
created_at=created_at,
|
username=username,
|
||||||
level=IssueLevel.ISSUE)
|
created_at=created_at,
|
||||||
|
level=IssueLevel.ISSUE,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
corpus.append(issue_record)
|
corpus.append(issue_record)
|
||||||
if comments:
|
if comments:
|
||||||
@ -412,15 +481,20 @@ class PRSimilarIssue:
|
|||||||
if num_words_comment < 10 or not isinstance(comment_body, str):
|
if num_words_comment < 10 or not isinstance(comment_body, str):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if len(comment_body) < 8000 or \
|
if (
|
||||||
self.token_handler.count_tokens(comment_body) < MAX_TOKENS[MODEL]:
|
len(comment_body) < 8000
|
||||||
|
or self.token_handler.count_tokens(comment_body)
|
||||||
|
< MAX_TOKENS[MODEL]
|
||||||
|
):
|
||||||
comment_record = Record(
|
comment_record = Record(
|
||||||
id=issue_key + ".comment_" + str(j + 1),
|
id=issue_key + ".comment_" + str(j + 1),
|
||||||
text=comment_body,
|
text=comment_body,
|
||||||
metadata=Metadata(repo=repo_name_for_index,
|
metadata=Metadata(
|
||||||
username=username, # use issue username for all comments
|
repo=repo_name_for_index,
|
||||||
created_at=created_at,
|
username=username, # use issue username for all comments
|
||||||
level=IssueLevel.COMMENT)
|
created_at=created_at,
|
||||||
|
level=IssueLevel.COMMENT,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
corpus.append(comment_record)
|
corpus.append(comment_record)
|
||||||
df = pd.DataFrame(corpus.dict()["documents"])
|
df = pd.DataFrame(corpus.dict()["documents"])
|
||||||
@ -446,7 +520,9 @@ class PRSimilarIssue:
|
|||||||
|
|
||||||
if not ingest:
|
if not ingest:
|
||||||
get_logger().info('Creating table from scratch...')
|
get_logger().info('Creating table from scratch...')
|
||||||
self.table = self.db.create_table(self.index_name, data=df, mode="overwrite")
|
self.table = self.db.create_table(
|
||||||
|
self.index_name, data=df, mode="overwrite"
|
||||||
|
)
|
||||||
time.sleep(15)
|
time.sleep(15)
|
||||||
else:
|
else:
|
||||||
get_logger().info('Ingesting in Table...')
|
get_logger().info('Ingesting in Table...')
|
||||||
|
|||||||
@ -20,13 +20,20 @@ CHANGELOG_LINES = 50
|
|||||||
|
|
||||||
|
|
||||||
class PRUpdateChangelog:
|
class PRUpdateChangelog:
|
||||||
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
def __init__(
|
||||||
|
self,
|
||||||
|
pr_url: str,
|
||||||
|
cli_mode=False,
|
||||||
|
args=None,
|
||||||
|
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler,
|
||||||
|
):
|
||||||
self.git_provider = get_git_provider()(pr_url)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
self.main_language = get_main_pr_language(
|
self.main_language = get_main_pr_language(
|
||||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||||
)
|
)
|
||||||
self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
|
self.commit_changelog = (
|
||||||
|
get_settings().pr_update_changelog.push_changelog_changes
|
||||||
|
)
|
||||||
self._get_changelog_file() # self.changelog_file_str
|
self._get_changelog_file() # self.changelog_file_str
|
||||||
|
|
||||||
self.ai_handler = ai_handler()
|
self.ai_handler = ai_handler()
|
||||||
@ -47,15 +54,19 @@ class PRUpdateChangelog:
|
|||||||
"extra_instructions": get_settings().pr_update_changelog.extra_instructions,
|
"extra_instructions": get_settings().pr_update_changelog.extra_instructions,
|
||||||
"commit_messages_str": self.git_provider.get_commit_messages(),
|
"commit_messages_str": self.git_provider.get_commit_messages(),
|
||||||
}
|
}
|
||||||
self.token_handler = TokenHandler(self.git_provider.pr,
|
self.token_handler = TokenHandler(
|
||||||
self.vars,
|
self.git_provider.pr,
|
||||||
get_settings().pr_update_changelog_prompt.system,
|
self.vars,
|
||||||
get_settings().pr_update_changelog_prompt.user)
|
get_settings().pr_update_changelog_prompt.system,
|
||||||
|
get_settings().pr_update_changelog_prompt.user,
|
||||||
|
)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
get_logger().info('Updating the changelog...')
|
get_logger().info('Updating the changelog...')
|
||||||
relevant_configs = {'pr_update_changelog': dict(get_settings().pr_update_changelog),
|
relevant_configs = {
|
||||||
'config': dict(get_settings().config)}
|
'pr_update_changelog': dict(get_settings().pr_update_changelog),
|
||||||
|
'config': dict(get_settings().config),
|
||||||
|
}
|
||||||
get_logger().debug("Relevant configs", artifacts=relevant_configs)
|
get_logger().debug("Relevant configs", artifacts=relevant_configs)
|
||||||
|
|
||||||
# currently only GitHub is supported for pushing changelog changes
|
# currently only GitHub is supported for pushing changelog changes
|
||||||
@ -74,13 +85,21 @@ class PRUpdateChangelog:
|
|||||||
if get_settings().config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
self.git_provider.publish_comment("准备变更日志更新中...", is_temporary=True)
|
self.git_provider.publish_comment("准备变更日志更新中...", is_temporary=True)
|
||||||
|
|
||||||
await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK)
|
await retry_with_fallback_models(
|
||||||
|
self._prepare_prediction, model_type=ModelType.WEAK
|
||||||
|
)
|
||||||
|
|
||||||
new_file_content, answer = self._prepare_changelog_update()
|
new_file_content, answer = self._prepare_changelog_update()
|
||||||
|
|
||||||
# Output the relevant configurations if enabled
|
# Output the relevant configurations if enabled
|
||||||
if get_settings().get('config', {}).get('output_relevant_configurations', False):
|
if (
|
||||||
answer += show_relevant_configurations(relevant_section='pr_update_changelog')
|
get_settings()
|
||||||
|
.get('config', {})
|
||||||
|
.get('output_relevant_configurations', False)
|
||||||
|
):
|
||||||
|
answer += show_relevant_configurations(
|
||||||
|
relevant_section='pr_update_changelog'
|
||||||
|
)
|
||||||
|
|
||||||
get_logger().debug(f"PR output", artifact=answer)
|
get_logger().debug(f"PR output", artifact=answer)
|
||||||
|
|
||||||
@@ -89,7 +108,9 @@ class PRUpdateChangelog:
         if self.commit_changelog:
             self._push_changelog_update(new_file_content, answer)
         else:
-            self.git_provider.publish_comment(f"**Changelog updates:** 🔄\n\n{answer}")
+            self.git_provider.publish_comment(
+                f"**Changelog updates:** 🔄\n\n{answer}"
+            )

     async def _prepare_prediction(self, model: str):
         self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
@@ -106,10 +127,18 @@ class PRUpdateChangelog:
         if get_settings().pr_update_changelog.add_pr_link:
             variables["pr_link"] = self.git_provider.get_pr_url()
         environment = Environment(undefined=StrictUndefined)
-        system_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.system).render(variables)
-        user_prompt = environment.from_string(get_settings().pr_update_changelog_prompt.user).render(variables)
+        system_prompt = environment.from_string(
+            get_settings().pr_update_changelog_prompt.system
+        ).render(variables)
+        user_prompt = environment.from_string(
+            get_settings().pr_update_changelog_prompt.user
+        ).render(variables)
         response, finish_reason = await self.ai_handler.chat_completion(
-            model=model, system=system_prompt, user=user_prompt, temperature=get_settings().config.temperature)
+            model=model,
+            system=system_prompt,
+            user=user_prompt,
+            temperature=get_settings().config.temperature,
+        )

         # post-process the response
         response = response.strip()
@@ -134,8 +163,10 @@ class PRUpdateChangelog:
             new_file_content = answer

         if not self.commit_changelog:
-            answer += "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" \
-                      "\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n"
+            answer += (
+                "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:"
+                "\n>'/update_changelog --pr_update_changelog.push_changelog_changes=true'\n"
+            )

         return new_file_content, answer

@@ -163,8 +194,7 @@ class PRUpdateChangelog:
             self.git_provider.publish_comment(f"**Changelog updates: 🔄**\n\n{answer}")

     def _get_default_changelog(self):
-        example_changelog = \
-"""
+        example_changelog = """
 Example:
 ## <current_date>

@@ -7,14 +7,15 @@ from utils.pr_agent.log import get_logger

 # Compile the regex pattern once, outside the function
 GITHUB_TICKET_PATTERN = re.compile(
     r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)'
 )

+
 def find_jira_tickets(text):
     # Regular expression patterns for JIRA tickets
     patterns = [
         r'\b[A-Z]{2,10}-\d{1,7}\b',  # Standard JIRA ticket format (e.g., PROJ-123)
-        r'(?:https?://[^\s/]+/browse/)?([A-Z]{2,10}-\d{1,7})\b'  # JIRA URL or just the ticket
+        r'(?:https?://[^\s/]+/browse/)?([A-Z]{2,10}-\d{1,7})\b',  # JIRA URL or just the ticket
     ]

     tickets = set()
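To make the compiled pattern in the hunk above easier to verify, here is a minimal, self-contained sketch of what it matches; the sample text and printed tuples are illustrative only and are not taken from the repository.

import re

# Same pattern as above: full issue URL | owner/repo#N shorthand | bare #N
GITHUB_TICKET_PATTERN = re.compile(
    r'(https://github[^/]+/[^/]+/[^/]+/issues/\d+)|(\b(\w+)/(\w+)#(\d+)\b)|(#\d+)'
)

text = "Fixes https://github.com/acme/widgets/issues/12, see acme/widgets#34 and #56"
for groups in GITHUB_TICKET_PATTERN.findall(text):
    print(groups)
# ('https://github.com/acme/widgets/issues/12', '', '', '', '', '')
# ('', 'acme/widgets#34', 'acme', 'widgets', '34', '')
# ('', '', '', '', '', '#56')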
@@ -32,7 +33,9 @@ def find_jira_tickets(text):
     return list(tickets)


-def extract_ticket_links_from_pr_description(pr_description, repo_path, base_url_html='https://github.com'):
+def extract_ticket_links_from_pr_description(
+    pr_description, repo_path, base_url_html='https://github.com'
+):
     """
     Extract all ticket links from PR description
     """
@@ -46,19 +49,27 @@ def extract_ticket_links_from_pr_description(pr_description, repo_path, base_url
                 github_tickets.add(match[0])
             elif match[1]:  # Shorthand notation match: owner/repo#issue_number
                 owner, repo, issue_number = match[2], match[3], match[4]
-                github_tickets.add(f'{base_url_html.strip("/")}/{owner}/{repo}/issues/{issue_number}')
+                github_tickets.add(
+                    f'{base_url_html.strip("/")}/{owner}/{repo}/issues/{issue_number}'
+                )
             else:  # #123 format
                 issue_number = match[5][1:]  # remove #
                 if issue_number.isdigit() and len(issue_number) < 5 and repo_path:
-                    github_tickets.add(f'{base_url_html.strip("/")}/{repo_path}/issues/{issue_number}')
+                    github_tickets.add(
+                        f'{base_url_html.strip("/")}/{repo_path}/issues/{issue_number}'
+                    )

         if len(github_tickets) > 3:
-            get_logger().info(f"Too many tickets found in PR description: {len(github_tickets)}")
+            get_logger().info(
+                f"Too many tickets found in PR description: {len(github_tickets)}"
+            )
             # Limit the number of tickets to 3
             github_tickets = set(list(github_tickets)[:3])
     except Exception as e:
-        get_logger().error(f"Error extracting tickets error= {e}",
-                           artifact={"traceback": traceback.format_exc()})
+        get_logger().error(
+            f"Error extracting tickets error= {e}",
+            artifact={"traceback": traceback.format_exc()},
+        )

     return list(github_tickets)

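The f-strings in this hunk expand shorthand references into full issue URLs; a tiny standalone illustration of that expansion (the sample values are hypothetical):

base_url_html = 'https://github.com/'
owner, repo, issue_number = 'acme', 'widgets', '34'
# strip("/") drops a trailing slash so the joined URL has exactly one separator
print(f'{base_url_html.strip("/")}/{owner}/{repo}/issues/{issue_number}')
# https://github.com/acme/widgets/issues/34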
@@ -68,19 +79,26 @@ async def extract_tickets(git_provider):
     try:
         if isinstance(git_provider, GithubProvider):
             user_description = git_provider.get_user_description()
-            tickets = extract_ticket_links_from_pr_description(user_description, git_provider.repo, git_provider.base_url_html)
+            tickets = extract_ticket_links_from_pr_description(
+                user_description, git_provider.repo, git_provider.base_url_html
+            )
             tickets_content = []

             if tickets:
-
                 for ticket in tickets:
-                    repo_name, original_issue_number = git_provider._parse_issue_url(ticket)
+                    repo_name, original_issue_number = git_provider._parse_issue_url(
+                        ticket
+                    )

                     try:
-                        issue_main = git_provider.repo_obj.get_issue(original_issue_number)
+                        issue_main = git_provider.repo_obj.get_issue(
+                            original_issue_number
+                        )
                     except Exception as e:
-                        get_logger().error(f"Error getting main issue: {e}",
-                                           artifact={"traceback": traceback.format_exc()})
+                        get_logger().error(
+                            f"Error getting main issue: {e}",
+                            artifact={"traceback": traceback.format_exc()},
+                        )
                         continue

                     issue_body_str = issue_main.body or ""
@@ -93,47 +111,66 @@ async def extract_tickets(git_provider):
                        sub_issues = git_provider.fetch_sub_issues(ticket)
                        for sub_issue_url in sub_issues:
                            try:
-                                sub_repo, sub_issue_number = git_provider._parse_issue_url(sub_issue_url)
-                                sub_issue = git_provider.repo_obj.get_issue(sub_issue_number)
+                                (
+                                    sub_repo,
+                                    sub_issue_number,
+                                ) = git_provider._parse_issue_url(sub_issue_url)
+                                sub_issue = git_provider.repo_obj.get_issue(
+                                    sub_issue_number
+                                )

                                sub_body = sub_issue.body or ""
                                if len(sub_body) > MAX_TICKET_CHARACTERS:
                                    sub_body = sub_body[:MAX_TICKET_CHARACTERS] + "..."

-                                sub_issues_content.append({
-                                    'ticket_url': sub_issue_url,
-                                    'title': sub_issue.title,
-                                    'body': sub_body
-                                })
+                                sub_issues_content.append(
+                                    {
+                                        'ticket_url': sub_issue_url,
+                                        'title': sub_issue.title,
+                                        'body': sub_body,
+                                    }
+                                )
                            except Exception as e:
-                                get_logger().warning(f"Failed to fetch sub-issue content for {sub_issue_url}: {e}")
+                                get_logger().warning(
+                                    f"Failed to fetch sub-issue content for {sub_issue_url}: {e}"
+                                )

                    except Exception as e:
-                        get_logger().warning(f"Failed to fetch sub-issues for {ticket}: {e}")
+                        get_logger().warning(
+                            f"Failed to fetch sub-issues for {ticket}: {e}"
+                        )

                    # Extract labels
                    labels = []
                    try:
                        for label in issue_main.labels:
-                            labels.append(label.name if hasattr(label, 'name') else label)
+                            labels.append(
+                                label.name if hasattr(label, 'name') else label
+                            )
                    except Exception as e:
-                        get_logger().error(f"Error extracting labels error= {e}",
-                                           artifact={"traceback": traceback.format_exc()})
+                        get_logger().error(
+                            f"Error extracting labels error= {e}",
+                            artifact={"traceback": traceback.format_exc()},
+                        )

-                    tickets_content.append({
-                        'ticket_id': issue_main.number,
-                        'ticket_url': ticket,
-                        'title': issue_main.title,
-                        'body': issue_body_str,
-                        'labels': ", ".join(labels),
-                        'sub_issues': sub_issues_content  # Store sub-issues content
-                    })
+                    tickets_content.append(
+                        {
+                            'ticket_id': issue_main.number,
+                            'ticket_url': ticket,
+                            'title': issue_main.title,
+                            'body': issue_body_str,
+                            'labels': ", ".join(labels),
+                            'sub_issues': sub_issues_content,  # Store sub-issues content
+                        }
+                    )

            return tickets_content

    except Exception as e:
-        get_logger().error(f"Error extracting tickets error= {e}",
-                           artifact={"traceback": traceback.format_exc()})
+        get_logger().error(
+            f"Error extracting tickets error= {e}",
+            artifact={"traceback": traceback.format_exc()},
+        )


async def extract_and_cache_pr_tickets(git_provider, vars):
@@ -154,8 +191,10 @@ async def extract_and_cache_pr_tickets(git_provider, vars):

         related_tickets.append(ticket)

-    get_logger().info("Extracted tickets and sub-issues from PR description",
-                      artifact={"tickets": related_tickets})
+    get_logger().info(
+        "Extracted tickets and sub-issues from PR description",
+        artifact={"tickets": related_tickets},
+    )

     vars['related_tickets'] = related_tickets
     get_settings().set('related_tickets', related_tickets)
13  config.ini  Normal file
@@ -0,0 +1,13 @@
+[BASE]
+; 是否开启debug模式 0或1
+DEBUG = 0
+
+[DATABASE]
+; 默认采用sqlite,线上需替换为pg
+DEFAULT = pg
+; postgres配置
+DB_NAME = pr_manager
+DB_USER = admin
+DB_PASSWORD = admin123456
+DB_HOST = 110.40.30.95
+DB_PORT = 5432
@@ -12,11 +12,18 @@ https://docs.djangoproject.com/en/5.1/ref/settings/

 import os
 import sys
+import configparser
 from pathlib import Path

 # Build paths inside the project like this: BASE_DIR / 'subdir'.
 BASE_DIR = Path(__file__).resolve().parent.parent

+CONFIG_NAME = BASE_DIR / "config.ini"
+
+# 加载配置文件: 开发可加载config.local.ini
+_config = configparser.ConfigParser()
+_config.read(CONFIG_NAME, encoding="utf-8")
+
 sys.path.insert(0, os.path.join(BASE_DIR, "apps"))
 sys.path.insert(1, os.path.join(BASE_DIR, "apps/utils"))

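The comment in this hunk mentions an optional config.local.ini for development, but the code as committed reads only config.ini. For reference, a hedged sketch of how configparser could layer a local override on top of the shared file; this is illustrative and not part of the commit:

import configparser
from pathlib import Path

BASE_DIR = Path(__file__).resolve().parent.parent

_config = configparser.ConfigParser()
# read() accepts a list of paths: files missing on disk are skipped silently,
# and values from later files override values from earlier ones.
_config.read(
    [BASE_DIR / "config.ini", BASE_DIR / "config.local.ini"],
    encoding="utf-8",
)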
@@ -27,7 +34,7 @@ sys.path.insert(1, os.path.join(BASE_DIR, "apps/utils"))
 SECRET_KEY = "django-insecure-$r6lfcq8rev&&=chw259o$0o7t-!!%clc2ahs3xg$^z+gkms76"

 # SECURITY WARNING: don't run with debug turned on in production!
-DEBUG = False
+DEBUG = bool(int(_config["BASE"].get("DEBUG", "1")))

 ALLOWED_HOSTS = ["*"]

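The bool(int(...)) chain is needed because configparser returns strings, and any non-empty string, including "0", is truthy. A small illustrative check, assuming the [BASE] section shown in config.ini:

import configparser

cfg = configparser.ConfigParser()
cfg.read_string("[BASE]\nDEBUG = 0\n")

print(bool(cfg["BASE"].get("DEBUG", "1")))             # True  -- a non-empty string is truthy
print(bool(int(cfg["BASE"].get("DEBUG", "1"))))        # False -- the conversion used in settings.py
print(cfg["BASE"].getboolean("DEBUG", fallback=True))  # False -- configparser's built-in alternative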
@@ -44,7 +51,7 @@ INSTALLED_APPS = [
     "django.contrib.messages",
     "django.contrib.staticfiles",
     "public",
-    "pr"
+    "pr",
 ]

 # 配置安全秘钥
@@ -68,8 +75,7 @@ ROOT_URLCONF = "pr_manager.urls"
 TEMPLATES = [
     {
         "BACKEND": "django.template.backends.django.DjangoTemplates",
-        "DIRS": [BASE_DIR / 'templates']
-        ,
+        "DIRS": [BASE_DIR / 'templates'],
         "APP_DIRS": True,
         "OPTIONS": {
             "context_processors": [
@@ -89,12 +95,22 @@ WSGI_APPLICATION = "pr_manager.wsgi.application"
 # https://docs.djangoproject.com/en/5.1/ref/settings/#databases

 DATABASES = {
-    "default": {
+    "pg": {
+        "ENGINE": "django.db.backends.postgresql",
+        "NAME": _config["DATABASE"].get("DB_NAME", "chat_ai_v2"),
+        "USER": _config["DATABASE"].get("DB_USER", "admin"),
+        "PASSWORD": _config["DATABASE"].get("DB_PASSWORD", "admin123456"),
+        "HOST": _config["DATABASE"].get("DB_HOST", "124.222.222.101"),
+        "PORT": int(_config["DATABASE"].get("DB_PORT", "5432")),
+    },
+    "sqlite": {
         "ENGINE": "django.db.backends.sqlite3",
         "NAME": BASE_DIR / "db.sqlite3",
-    }
+    },
 }

+DATABASES["default"] = DATABASES[_config["DATABASE"].get("DEFAULT", "sqlite")]
+

 # Password validation
 # https://docs.djangoproject.com/en/5.1/ref/settings/#auth-password-validators
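How the DEFAULT switch resolves at import time, as a minimal standalone sketch (the aliases and engines mirror the block above; the plain dict here is a stand-in, not Django's actual settings object):

_DATABASES = {
    "pg": {"ENGINE": "django.db.backends.postgresql"},
    "sqlite": {"ENGINE": "django.db.backends.sqlite3"},
}

default_alias = "pg"  # the DATABASE/DEFAULT value from config.ini; "sqlite" is the fallback
_DATABASES["default"] = _DATABASES[default_alias]  # raises KeyError if DEFAULT names an unknown alias

print(_DATABASES["default"]["ENGINE"])  # django.db.backends.postgresql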