303 lines
10 KiB
Python
303 lines
10 KiB
Python
from abc import ABC, abstractmethod
|
|
from threading import Thread
|
|
|
|
from django.http import JsonResponse
|
|
|
|
from pr import models
|
|
from utils.pr_agent import cli
|
|
from utils.pr_agent.config_loader import get_settings
|
|
from utils import constant
|
|
|
|
|
|
class GitProvider(ABC):
    """
    Abstract interface for the per-platform webhook handlers in this module.

    A concrete subclass must override all three abstract methods below;
    instantiating a subclass that leaves any of them out raises TypeError.
    """

    @abstractmethod
    def get_project_config(self, project_id, git_type):
        """Return the stored configuration for ``project_id`` (or an error response)."""

    @abstractmethod
    def get_merge_request(
        self,
        request_data,
        git_url,
        access_token,
        api_base,
        api_key,
        llm_model,
        project_commands,
    ):
        """Handle a merge/pull-request webhook payload and start ``project_commands``."""

    @abstractmethod
    def run_command(self, mr_url, project_commands):
        """Run the configured ``project_commands`` against the request at ``mr_url``."""
|
|
|
|
|
|
class GitLabProvider(GitProvider):
    """
    GitLab implementation of ``GitProvider``.

    Verifies the GitLab webhook token, loads the matching ``ProjectConfig``
    (restricted to ``git_config__git_type=0``), and starts the configured
    pr-agent commands for opened merge requests.
    """

    @staticmethod
    def check_secret(request_headers, project_secret, git_type):
        """
        Check the webhook secret sent in the ``X-Gitlab-Token`` header.

        :param request_headers: request header mapping (anything with ``.get``).
        :param project_secret: secret stored on the project configuration.
        :param git_type: unused; kept for interface compatibility.
        :return: a 403 ``JsonResponse`` when the token does not match,
            otherwise ``None``.
        """
        import hmac

        token = request_headers.get("X-Gitlab-Token")
        if isinstance(token, str) and isinstance(project_secret, str):
            # Constant-time comparison so response timing does not reveal how
            # much of a guessed secret was correct.
            matched = hmac.compare_digest(
                token.encode("utf-8"), project_secret.encode("utf-8")
            )
        else:
            # Non-string values (e.g. a missing header gives None) keep the
            # original plain-equality semantics.
            matched = token == project_secret
        if not matched:
            return JsonResponse(status=403, data={"error": "Invalid token"})
        return None

    @staticmethod
    def get_project_id(request_json, git_type):
        """
        Extract the GitLab project id from a webhook payload.

        :param request_json: decoded webhook payload.
        :param git_type: unused; kept for interface compatibility.
        :return: ``payload["project"]["id"]``, or ``None`` when absent.
        """
        return (request_json.get("project") or {}).get("id")

    def get_project_config(self, project_id, git_type):
        """
        Load the pr-agent configuration for a GitLab project.

        :param project_id: project id taken from the webhook payload.
        :param git_type: unused; the lookup is always restricted to
            ``git_config__git_type=0`` (GitLab).
        :return: a 400 ``JsonResponse`` when the project is unknown or
            disabled; otherwise a dict with the LLM settings (``api_base``,
            ``api_key``, ``llm_model``), git access settings (``git_url``,
            ``git_type``, ``access_token``), ``project_secret``, the list of
            configured ``commands`` and the ``ProjectConfig`` primary key.
        """
        project_config = models.ProjectConfig.objects.filter(
            project_id=project_id, git_config__git_type=0
        ).first()
        # ``.first()`` returns None when nothing matches, so this check must
        # come before any attribute access on ``project_config``.
        if not project_config:
            return JsonResponse(status=400, data={"error": "Project not found"})
        if not project_config.is_enable:
            return JsonResponse(status=400, data={"error": "Project is disabled"})

        git_config = project_config.git_config  # Gitlab
        # Strip whitespace and drop empty entries, so "review, describe" and
        # an empty/NULL field both yield clean command names.
        commands = [
            cmd.strip()
            for cmd in (project_config.commands or "").split(",")
            if cmd.strip()
        ]
        return {
            "api_base": git_config.pr_ai.api_base,
            "api_key": git_config.pr_ai.api_key,
            "llm_model": git_config.pr_ai.llm_model,
            "git_url": git_config.git_url,
            "git_type": "gitlab",
            "access_token": git_config.access_token,
            "project_secret": project_config.project_secret,
            "commands": commands,
            "project_id": project_config.id,
        }

    def get_merge_request(
        self,
        request_data,
        git_url,
        access_token,
        api_base,
        api_key,
        llm_model,
        project_commands,
    ):
        """
        Handle a GitLab merge-request webhook and start the configured commands.

        Only ``object_kind == "merge_request"`` events whose merge request is
        in state ``"opened"`` and carries a URL are processed. ``"update"``
        actions whose ``oldrev`` equals ``newrev`` (including both absent) are
        acknowledged but ignored.

        :param request_data: decoded webhook payload.
        :param git_url: GitLab base URL, written to ``gitlab.url``.
        :param access_token: token written to ``gitlab.personal_access_token``.
        :param api_base: LLM endpoint, written to ``openai.api_base``.
        :param api_key: LLM key, written to ``openai.key``.
        :param llm_model: model name, written to ``llm.model``.
        :param project_commands: command names to run (filtered by ``run_command``).
        :return: a ``JsonResponse`` describing the outcome.
        """
        merge_request = {}
        if request_data.get("object_kind") == "merge_request":
            merge_request = request_data.get("object_attributes") or {}
        mr_url = merge_request.get("url")
        if merge_request.get("state") != "opened" or not mr_url:
            return JsonResponse(
                status=400,
                data={"error": "Merge request URL not found or action not open"},
            )

        if merge_request.get("action") == "update" and merge_request.get(
            "oldrev"
        ) == merge_request.get("newrev"):
            return JsonResponse(
                status=200, data={"status": "ignored (no code change)"}
            )

        # NOTE(review): these values are written into the object returned by
        # get_settings(); if that is process-wide, concurrent webhooks for
        # different projects may overwrite each other's values before the
        # worker threads read them. Confirm against pr_agent's config_loader.
        settings = get_settings()
        settings.set("config.git_provider", "gitlab")
        settings.set("gitlab.url", git_url)
        settings.set("gitlab.personal_access_token", access_token)
        settings.set("openai.api_base", api_base)
        settings.set("openai.key", api_key)
        settings.set("llm.model", llm_model)

        self.run_command(mr_url, project_commands)
        return JsonResponse(status=200, data={"status": "review started"})

    @staticmethod
    def save_pr_agent_log(request_data, project_id):
        """
        Record a ``ProjectHistory`` row for a merge-request webhook.

        Nothing is written unless the payload's ``object_attributes`` has both
        a ``source_branch`` and a ``target_branch``.

        :param request_data: decoded webhook payload (stored as ``source_data``).
        :param project_id: value stored in ``ProjectHistory.project_id``.
        :return: None.
        """
        attributes = request_data.get("object_attributes") or {}
        source_branch = attributes.get("source_branch")
        target_branch = attributes.get("target_branch")
        if not (source_branch and target_branch):
            return
        models.ProjectHistory.objects.create(
            project_id=project_id,
            project_url=(request_data.get("project") or {}).get("web_url"),
            mr_url=attributes.get("url"),
            source_branch=source_branch,
            target_branch=target_branch,
            mr_title=attributes.get("title"),
            source_data=request_data,
        )

    def run_command(self, mr_url, project_commands):
        """
        Start each allowed pr-agent command for ``mr_url`` in its own thread.

        Commands whose name is not the second element of an entry in
        ``constant.DEFAULT_COMMANDS`` are skipped. Threads are started but
        never joined, so this returns before the commands finish.

        :param mr_url: merge-request URL passed to ``cli.run_command``.
        :param project_commands: iterable of command names.
        :return: None.
        """
        # Build the allow-list once rather than once per command. Entries are
        # presumably (label, command) choice pairs — TODO confirm in constant.
        allowed = {entry[1] for entry in constant.DEFAULT_COMMANDS}
        for cmd in project_commands:
            if cmd not in allowed:
                continue
            Thread(target=cli.run_command, args=(mr_url, cmd)).start()
|
|
|
|
|
|
class GiteaProvider(GitProvider):
    """
    Gitea implementation of ``GitProvider``.

    Verifies the webhook ``Authorization`` header, loads the matching
    ``ProjectConfig`` (restricted to ``git_config__git_type=2``), and starts
    the configured pr-agent commands for newly opened pull requests.
    """

    @staticmethod
    def check_secret(request_headers, project_secret, git_type):
        """
        Check the webhook secret sent in the ``Authorization`` header.

        :param request_headers: request header mapping (anything with ``.get``).
        :param project_secret: secret stored on the project configuration.
        :param git_type: unused; kept for interface compatibility.
        :return: a 403 ``JsonResponse`` when the header does not match,
            otherwise ``None``.
        """
        import hmac

        token = request_headers.get("Authorization")
        if isinstance(token, str) and isinstance(project_secret, str):
            # Constant-time comparison so response timing does not reveal how
            # much of a guessed secret was correct.
            matched = hmac.compare_digest(
                token.encode("utf-8"), project_secret.encode("utf-8")
            )
        else:
            # Non-string values (e.g. a missing header gives None) keep the
            # original plain-equality semantics.
            matched = token == project_secret
        if not matched:
            return JsonResponse(status=403, data={"error": "Invalid token"})
        return None

    @staticmethod
    def get_project_id(request_json, git_type):
        """
        Extract the Gitea repository id from a webhook payload.

        :param request_json: decoded webhook payload.
        :param git_type: unused; kept for interface compatibility.
        :return: ``payload["repository"]["id"]``, or ``None`` when absent.
        """
        return (request_json.get("repository") or {}).get("id")

    def get_project_config(self, project_id, git_type):
        """
        Load the pr-agent configuration for a Gitea repository.

        :param project_id: repository id taken from the webhook payload.
        :param git_type: unused; the lookup is always restricted to
            ``git_config__git_type=2`` (Gitea).
        :return: a 400 ``JsonResponse`` when the project is unknown or
            disabled; otherwise a dict with the LLM settings (``api_base``,
            ``api_key``, ``llm_model``), git access settings (``git_url``,
            ``git_type``, ``access_token``), ``project_secret``, the list of
            configured ``commands`` and the ``ProjectConfig`` primary key.
        """
        project_config = models.ProjectConfig.objects.filter(
            project_id=project_id, git_config__git_type=2
        ).first()
        # ``.first()`` returns None when nothing matches, so this check must
        # come before any attribute access on ``project_config``.
        if not project_config:
            return JsonResponse(status=400, data={"error": "Project not found"})
        if not project_config.is_enable:
            return JsonResponse(status=400, data={"error": "Project is disabled"})

        git_config = project_config.git_config  # Gitea
        # Strip whitespace and drop empty entries, so "review, describe" and
        # an empty/NULL field both yield clean command names.
        commands = [
            cmd.strip()
            for cmd in (project_config.commands or "").split(",")
            if cmd.strip()
        ]
        return {
            "api_base": git_config.pr_ai.api_base,
            "api_key": git_config.pr_ai.api_key,
            "llm_model": git_config.pr_ai.llm_model,
            "git_url": git_config.git_url,
            "git_type": "gitea",
            "access_token": git_config.access_token,
            "project_secret": project_config.project_secret,
            "commands": commands,
            "project_id": project_config.id,
        }

    def get_merge_request(
        self,
        request_data,
        git_url,
        access_token,
        api_base,
        api_key,
        llm_model,
        project_commands,
    ):
        """
        Handle a Gitea pull-request webhook and start the configured commands.

        Only payloads with ``action == "opened"`` whose ``pull_request``
        carries a URL are processed.

        :param request_data: decoded webhook payload.
        :param git_url: Gitea base URL.
        :param access_token: Gitea access token.
        :param api_base: LLM endpoint, written to ``openai.api_base``.
        :param api_key: LLM key, written to ``openai.key``.
        :param llm_model: model name, written to ``llm.model``.
        :param project_commands: command names to run (filtered by ``run_command``).
        :return: a ``JsonResponse`` describing the outcome.
        """
        pull_request = {}
        if request_data.get("action") == "opened":
            pull_request = request_data.get("pull_request") or {}
        mr_url = pull_request.get("url")
        if not mr_url:
            return JsonResponse(
                status=400,
                data={"error": "Merge request URL not found or action not open"},
            )

        # NOTE(review): these values are written into the object returned by
        # get_settings(); if that is process-wide, concurrent webhooks for
        # different projects may overwrite each other's values before the
        # worker threads read them. Confirm against pr_agent's config_loader.
        settings = get_settings()
        settings.set("config.git_provider", "gitea")
        # Gitea-specific keys for the "gitea" provider selected above
        # (upstream pr-agent reads GITEA.URL / GITEA.PERSONAL_ACCESS_TOKEN).
        settings.set("gitea.url", git_url)
        settings.set("gitea.personal_access_token", access_token)
        # Kept from the original implementation in case the bundled pr_agent
        # reads the gitlab.* keys for Gitea — TODO confirm and drop if unused.
        settings.set("gitlab.url", git_url)
        settings.set("gitlab.personal_access_token", access_token)
        settings.set("openai.api_base", api_base)
        settings.set("openai.key", api_key)
        settings.set("llm.model", llm_model)

        self.run_command(mr_url, project_commands)
        return JsonResponse(status=200, data={"status": "review started"})

    @staticmethod
    def save_pr_agent_log(request_data, project_id):
        """
        Record a ``ProjectHistory`` row for a pull-request webhook.

        Nothing is written unless ``pull_request`` has both a ``head`` and a
        ``base``; their ``label`` values are stored as source/target branch.

        :param request_data: decoded webhook payload (stored as ``source_data``).
        :param project_id: value stored in ``ProjectHistory.project_id``.
        :return: None.
        """
        pull_request = request_data.get("pull_request") or {}
        head = pull_request.get("head")
        base = pull_request.get("base")
        if not (head and base):
            return
        models.ProjectHistory.objects.create(
            project_id=project_id,
            project_url=(request_data.get("repository") or {}).get("html_url"),
            mr_url=pull_request.get("url"),
            source_branch=head.get("label"),
            target_branch=base.get("label"),
            mr_title=pull_request.get("title"),
            source_data=request_data,
        )

    def run_command(self, mr_url, project_commands):
        """
        Start each allowed pr-agent command for ``mr_url`` in its own thread.

        Commands whose name is not the second element of an entry in
        ``constant.DEFAULT_COMMANDS`` are skipped. Threads are started but
        never joined, so this returns before the commands finish.

        :param mr_url: pull-request URL passed to ``cli.run_command``.
        :param project_commands: iterable of command names.
        :return: None.
        """
        # Build the allow-list once rather than once per command. Entries are
        # presumably (label, command) choice pairs — TODO confirm in constant.
        allowed = {entry[1] for entry in constant.DEFAULT_COMMANDS}
        for cmd in project_commands:
            if cmd not in allowed:
                continue
            Thread(target=cli.run_command, args=(mr_url, cmd)).start()