Skip to content

Commit

Permalink
Merge pull request #574 from Calcium-Ion/channel-tag
Browse files Browse the repository at this point in the history
feat: 初步集成渠道标签分组功能
  • Loading branch information
Calcium-Ion authored Nov 30, 2024
2 parents 821f3a7 + 6693072 commit 48abfd0
Show file tree
Hide file tree
Showing 11 changed files with 1,476 additions and 587 deletions.
125 changes: 122 additions & 3 deletions controller/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,37 @@ func GetAllChannels(c *gin.Context) {
})
return
}
tags := make(map[string]bool)
channelData := make([]*model.Channel, 0, len(channels))
tagChannels := make([]*model.Channel, 0)
for _, channel := range channels {
channelTag := channel.GetTag()
if channelTag != "" && !tags[channelTag] {
tags[channelTag] = true
tagChannel, err := model.GetChannelsByTag(channelTag)
if err == nil {
tagChannels = append(tagChannels, tagChannel...)
}
} else {
channelData = append(channelData, channel)
}
}
for i, channel := range tagChannels {
find := false
for _, can := range channelData {
if channel.Id == can.Id {
find = true
break
}
}
if !find {
channelData = append(channelData, tagChannels[i])
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": channels,
"data": channelData,
})
return
}
Expand Down Expand Up @@ -144,8 +171,8 @@ func SearchChannels(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
modelKeyword := c.Query("model")
//idSort, _ := strconv.ParseBool(c.Query("id_sort"))
channels, err := model.SearchChannels(keyword, group, modelKeyword)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
Expand Down Expand Up @@ -279,6 +306,98 @@ func DeleteDisabledChannel(c *gin.Context) {
return
}

type ChannelTag struct {
Tag string `json:"tag"`
NewTag *string `json:"new_tag"`
Priority *int64 `json:"priority"`
Weight *uint `json:"weight"`
ModelMapping *string `json:"model_mapping"`
Models *string `json:"models"`
Groups *string `json:"groups"`
}

func DisableTagChannels(c *gin.Context) {
channelTag := ChannelTag{}
err := c.ShouldBindJSON(&channelTag)
if err != nil || channelTag.Tag == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
err = model.DisableChannelByTag(channelTag.Tag)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

func EnableTagChannels(c *gin.Context) {
channelTag := ChannelTag{}
err := c.ShouldBindJSON(&channelTag)
if err != nil || channelTag.Tag == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
err = model.EnableChannelByTag(channelTag.Tag)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

func EditTagChannels(c *gin.Context) {
channelTag := ChannelTag{}
err := c.ShouldBindJSON(&channelTag)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
if channelTag.Tag == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "tag不能为空",
})
return
}
err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

type ChannelBatch struct {
Ids []int `json:"ids"`
}
Expand Down
32 changes: 26 additions & 6 deletions model/ability.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ import (
)

type Ability struct {
Group string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
Weight uint `json:"weight" gorm:"default:0;index"`
Group string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
Weight uint `json:"weight" gorm:"default:0;index"`
Tag *string `json:"tag" gorm:"index"`
}

func GetGroupModels(group string) []string {
Expand Down Expand Up @@ -149,6 +150,7 @@ func (channel *Channel) AddAbilities() error {
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
Weight: uint(channel.GetWeight()),
Tag: channel.Tag,
}
abilities = append(abilities, ability)
}
Expand Down Expand Up @@ -190,6 +192,24 @@ func UpdateAbilityStatus(channelId int, status bool) error {
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
}

func UpdateAbilityStatusByTag(tag string, status bool) error {
return DB.Model(&Ability{}).Where("tag = ?", tag).Select("enabled").Update("enabled", status).Error
}

func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uint) error {
ability := Ability{}
if newTag != nil {
ability.Tag = newTag
}
if priority != nil {
ability.Priority = priority
}
if weight != nil {
ability.Weight = *weight
}
return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
}

func FixAbility() (int, error) {
var channelIds []int
count := 0
Expand Down
95 changes: 93 additions & 2 deletions model/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Channel struct {
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"`
OtherInfo string `json:"other_info"`
Tag *string `json:"tag" gorm:"index"`
}

func (channel *Channel) GetModels() []string {
Expand Down Expand Up @@ -61,6 +62,17 @@ func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
channel.OtherInfo = string(otherInfoBytes)
}

func (channel *Channel) GetTag() string {
if channel.Tag == nil {
return ""
}
return *channel.Tag
}

func (channel *Channel) SetTag(tag string) {
channel.Tag = &tag
}

func (channel *Channel) GetAutoBan() bool {
if channel.AutoBan == nil {
return false
Expand All @@ -87,7 +99,13 @@ func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Chan
return channels, err
}

func SearchChannels(keyword string, group string, model string) ([]*Channel, error) {
func GetChannelsByTag(tag string) ([]*Channel, error) {
var channels []*Channel
err := DB.Where("tag = ?", tag).Find(&channels).Error
return channels, err
}

func SearchChannels(keyword string, group string, model string, idSort bool) ([]*Channel, error) {
var channels []*Channel
keyCol := "`key`"
groupCol := "`group`"
Expand All @@ -100,6 +118,11 @@ func SearchChannels(keyword string, group string, model string) ([]*Channel, err
modelsCol = `"models"`
}

order := "priority desc"
if idSort {
order = "id desc"
}

// 构造基础查询
baseQuery := DB.Model(&Channel{}).Omit(keyCol)

Expand All @@ -122,7 +145,7 @@ func SearchChannels(keyword string, group string, model string) ([]*Channel, err
}

// 执行查询
err := baseQuery.Where(whereClause, args...).Order("priority desc").Find(&channels).Error
err := baseQuery.Where(whereClause, args...).Order(order).Find(&channels).Error
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -288,6 +311,74 @@ func UpdateChannelStatusById(id int, status int, reason string) {

}

func EnableChannelByTag(tag string) error {
err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusEnabled).Error
if err != nil {
return err
}
err = UpdateAbilityStatusByTag(tag, true)
return err
}

func DisableChannelByTag(tag string) error {
err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusManuallyDisabled).Error
if err != nil {
return err
}
err = UpdateAbilityStatusByTag(tag, false)
return err
}

func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *string, group *string, priority *int64, weight *uint) error {
updateData := Channel{}
shouldReCreateAbilities := false
updatedTag := tag
// 如果 newTag 不为空且不等于 tag,则更新 tag
if newTag != nil && *newTag != tag {
updateData.Tag = newTag
updatedTag = *newTag
}
if modelMapping != nil && *modelMapping != "" {
updateData.ModelMapping = modelMapping
}
if models != nil && *models != "" {
shouldReCreateAbilities = true
updateData.Models = *models
}
if group != nil && *group != "" {
shouldReCreateAbilities = true
updateData.Group = *group
}
if priority != nil {
updateData.Priority = priority
}
if weight != nil {
updateData.Weight = weight
}

err := DB.Model(&Channel{}).Where("tag = ?", tag).Updates(updateData).Error
if err != nil {
return err
}
if shouldReCreateAbilities {
channels, err := GetChannelsByTag(updatedTag)
if err == nil {
for _, channel := range channels {
err = channel.UpdateAbilities()
if err != nil {
common.SysError("failed to update abilities: " + err.Error())
}
}
}
} else {
err := UpdateAbilityByTag(tag, newTag, priority, weight)
if err != nil {
return err
}
}
return nil
}

func UpdateChannelUsedQuota(id int, quota int) {
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
Expand Down
5 changes: 5 additions & 0 deletions relay/channel/openai/relay-openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
shouldSendLastResp = false
}
}
for _, choice := range lastStreamResponse.Choices {
if choice.FinishReason != nil {
shouldSendLastResp = true
}
}
}
if shouldSendLastResp {
service.StringData(c, lastStreamData)
Expand Down
3 changes: 3 additions & 0 deletions router/api-router.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.POST("/", controller.AddChannel)
channelRoute.PUT("/", controller.UpdateChannel)
channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
channelRoute.POST("/tag/disabled", controller.DisableTagChannels)
channelRoute.POST("/tag/enabled", controller.EnableTagChannels)
channelRoute.PUT("/tag", controller.EditTagChannels)
channelRoute.DELETE("/:id", controller.DeleteChannel)
channelRoute.POST("/batch", controller.DeleteChannelBatch)
channelRoute.POST("/fix", controller.FixChannelsAbilities)
Expand Down
Loading

0 comments on commit 48abfd0

Please sign in to comment.