Skip to content

Commit

Permalink
fix: ai proxy (#6038)
Browse files Browse the repository at this point in the history
* ai-proxy client proto use clientId field to auto check permission

* session: set min_len of name&scene to 2; modelId is optional when update

* only admin can paging client

* use default model if session's modelId is empty
  • Loading branch information
sfwn authored Sep 11, 2023
1 parent 027bb57 commit b0ad278
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 27 deletions.
14 changes: 7 additions & 7 deletions api/proto/apps/aiproxy/client/client.proto
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@ service ClientService {

rpc Get(ClientGetRequest) returns (Client) {
option(google.api.http) = {
get: "/api/ai-proxy/clients/{id}"
get: "/api/ai-proxy/clients/{clientId}"
};
}

rpc Delete(ClientDeleteRequest) returns (common.VoidResponse) {
option(google.api.http) = {
delete: "/api/ai-proxy/clients/{id}"
delete: "/api/ai-proxy/clients/{clientId}"
};
}

rpc Update(ClientUpdateRequest) returns (Client) {
option(google.api.http) = {
put: "/api/ai-proxy/clients/{id}"
put: "/api/ai-proxy/clients/{clientId}"
};
}

Expand Down Expand Up @@ -63,15 +63,15 @@ message ClientCreateRequest {
}

message ClientGetRequest {
string id = 1 [(validate.rules).string = {len: 36}];
string clientId = 1 [(validate.rules).string = {len: 36}];
}

message ClientDeleteRequest {
string id = 1 [(validate.rules).string = {len: 36}];
string clientId = 1 [(validate.rules).string = {len: 36}];
}

message ClientUpdateRequest {
string id = 1 [(validate.rules).string = {len: 36}];
string clientId = 1 [(validate.rules).string = {len: 36}];
string name = 2 [(validate.rules).string.min_len = 4, (validate.rules).string.max_len = 191];
string desc = 3 [(validate.rules).string.max_len = 1024];
string accessKeyId = 4 [(validate.rules).string = {min_len: 32, max_len: 32}];
Expand All @@ -83,7 +83,7 @@ message ClientPagingRequest {
int64 pageNum = 1;
int64 pageSize = 2;

repeated string ids = 3;
repeated string ids = 3 [(validate.rules).repeated.items.string = {len: 36}];
string name = 4;
repeated string accessKeyIds = 5;
}
Expand Down
16 changes: 8 additions & 8 deletions api/proto/apps/aiproxy/session/session.proto
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ message SessionCreateRequest {
string clientId = 1 [(validate.rules).string = {ignore_empty: true, len: 36}];
string promptId = 2 [(validate.rules).string = {max_len: 36}];
string modelId = 3 [(validate.rules).string = {ignore_empty: true, len: 36}];
string scene = 4 [(validate.rules).string = {min_len: 4, max_len: 191}];
string userId = 5;
string name = 6 [(validate.rules).string = {min_len: 4, max_len: 191}];
string userId = 4;
string scene = 5 [(validate.rules).string = {min_len: 2, max_len: 191}];
string name = 6 [(validate.rules).string = {min_len: 2, max_len: 191}];
string topic = 7 [(validate.rules).string = {min_len: 1}];
uint64 numOfCtxMsg = 8;
double temperature = 9 [(validate.rules).double = {gte: 0, lte: 2}];
Expand All @@ -133,10 +133,10 @@ message SessionUpdateRequest {
string id = 1 [(validate.rules).string = {len: 36}];
string clientId = 2 [(validate.rules).string = {len: 36}];
string promptId = 3 [(validate.rules).string = {ignore_empty: true, len: 36}];
string modelId = 4 [(validate.rules).string = {len: 36}];
string modelId = 4 [(validate.rules).string = {ignore_empty: true, len: 36}];
string userId = 5 [(validate.rules).string = {max_len: 191}];
string scene = 6 [(validate.rules).string = {min_len: 4, max_len: 191}];
string name = 7 [(validate.rules).string = {min_len: 4, max_len: 191}];
string scene = 6 [(validate.rules).string = {min_len: 2, max_len: 191}];
string name = 7 [(validate.rules).string = {min_len: 2, max_len: 191}];
string topic = 8 [(validate.rules).string = {min_len: 1}];
uint64 numOfCtxMsg = 9;
double temperature = 10 [(validate.rules).double = {gte: 0, lte: 2}];
Expand All @@ -148,8 +148,8 @@ message SessionPagingRequest {
string promptId = 2 [(validate.rules).string = {ignore_empty: true, len: 36}];
string modelId = 3 [(validate.rules).string = {ignore_empty: true, len: 36}];
string userId = 4 [(validate.rules).string = {max_len: 191}];
string scene = 5 [(validate.rules).string = {ignore_empty: true, min_len: 4, max_len: 191}];
string name = 6 [(validate.rules).string = {ignore_empty: true, min_len: 4, max_len: 191}];
string scene = 5 [(validate.rules).string = {ignore_empty: true, min_len: 2, max_len: 191}];
string name = 6 [(validate.rules).string = {ignore_empty: true, min_len: 2, max_len: 191}];
bool isArchived = 7;
uint64 pageSize = 8 [(validate.rules).uint64 = {gte: 1}];
uint64 pageNum = 9 [(validate.rules).uint64 = {gte: 1, lte: 1000}];
Expand Down
17 changes: 10 additions & 7 deletions internal/apps/ai-proxy/filters/context/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,15 @@ func (f *Context) OnRequest(ctx context.Context, w http.ResponseWriter, infor re
return reverseproxy.Intercept, err
}
session = _session
sessionModel, err := q.ModelClient().Get(ctx, &modelpb.ModelGetRequest{Id: session.ModelId})
if err != nil {
l.Errorf("failed to get model, id: %s, err: %v", session.ModelId, err)
http.Error(w, "ModelId is invalid", http.StatusBadRequest)
return reverseproxy.Intercept, err
if session.ModelId != "" {
sessionModel, err := q.ModelClient().Get(ctx, &modelpb.ModelGetRequest{Id: session.ModelId})
if err != nil {
l.Errorf("failed to get model, id: %s, err: %v", session.ModelId, err)
http.Error(w, "ModelId is invalid", http.StatusBadRequest)
return reverseproxy.Intercept, err
}
model = sessionModel
}
model = sessionModel
} else if headerModelId != "" {
// get from model header
if headerModelId == "" {
Expand All @@ -115,7 +117,8 @@ func (f *Context) OnRequest(ctx context.Context, w http.ResponseWriter, infor re
return reverseproxy.Intercept, err
}
model = headerModel
} else {
}
if model == nil {
// get client default model
clientPbMeta := metadata.FromProtobuf(client.Metadata)
clientMeta, err := clientPbMeta.ToClientMeta()
Expand Down
2 changes: 1 addition & 1 deletion internal/apps/ai-proxy/handlers/permission/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ var CheckClientPerm = CheckPermissions(
&MethodPermission{Method: clientpb.ClientServiceServer.Create, OnlyAdmin: true},
&MethodPermission{Method: clientpb.ClientServiceServer.Get, AdminOrAk: true},
&MethodPermission{Method: clientpb.ClientServiceServer.Update, OnlyAdmin: true},
&MethodPermission{Method: clientpb.ClientServiceServer.Paging, AdminOrAk: true},
&MethodPermission{Method: clientpb.ClientServiceServer.Paging, OnlyAdmin: true},
&MethodPermission{Method: clientpb.ClientServiceServer.Delete, OnlyAdmin: true},
)

Expand Down
8 changes: 4 additions & 4 deletions internal/apps/ai-proxy/models/client/dbclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ func (dbClient *DBClient) Create(ctx context.Context, req *pb.ClientCreateReques
}

func (dbClient *DBClient) Get(ctx context.Context, req *pb.ClientGetRequest) (*pb.Client, error) {
c := &Client{BaseModel: common.BaseModelWithID(req.Id)}
c := &Client{BaseModel: common.BaseModelWithID(req.ClientId)}
if err := dbClient.DB.Model(c).First(c).Error; err != nil {
return nil, err
}
return c.ToProtobuf(), nil
}

func (dbClient *DBClient) Delete(ctx context.Context, req *pb.ClientDeleteRequest) (*commonpb.VoidResponse, error) {
c := &Client{BaseModel: common.BaseModelWithID(req.Id)}
c := &Client{BaseModel: common.BaseModelWithID(req.ClientId)}
sql := dbClient.DB.Model(c).Delete(c)
if sql.Error != nil {
return nil, sql.Error
Expand All @@ -72,7 +72,7 @@ func (dbClient *DBClient) Delete(ctx context.Context, req *pb.ClientDeleteReques

func (dbClient *DBClient) Update(ctx context.Context, req *pb.ClientUpdateRequest) (*pb.Client, error) {
c := &Client{
BaseModel: common.BaseModelWithID(req.Id),
BaseModel: common.BaseModelWithID(req.ClientId),
Name: req.Name,
Desc: req.Desc,
AccessKeyID: req.AccessKeyId,
Expand All @@ -86,7 +86,7 @@ func (dbClient *DBClient) Update(ctx context.Context, req *pb.ClientUpdateReques
if sql.RowsAffected != 1 {
return nil, gorm.ErrRecordNotFound
}
return dbClient.Get(ctx, &pb.ClientGetRequest{Id: req.Id})
return dbClient.Get(ctx, &pb.ClientGetRequest{ClientId: req.ClientId})
}

func (dbClient *DBClient) Paging(ctx context.Context, req *pb.ClientPagingRequest) (*pb.ClientPagingResponse, error) {
Expand Down

0 comments on commit b0ad278

Please sign in to comment.