Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix sql injection in api module #733

Merged
merged 1 commit into from
Apr 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 62 additions & 70 deletions modules/api/app/controller/alarm/alarm_events_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ package alarm

import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
h "github.com/open-falcon/falcon-plus/modules/api/app/helper"
alm "github.com/open-falcon/falcon-plus/modules/api/app/model/alarm"
"strings"
Expand Down Expand Up @@ -51,72 +51,52 @@ func (input APIGetAlarmListsInputs) checkInputsContain() error {
return nil
}

func (s APIGetAlarmListsInputs) collectFilters() string {
tmp := []string{}
func (s APIGetAlarmListsInputs) collectDBFilters(database *gorm.DB, tableName string, columns []string) *gorm.DB {
filterDB := database.Table(tableName)
// nil columns mean select all columns
if columns != nil && len(columns) != 0 {
filterDB = filterDB.Select(columns)
}
if s.StartTime != 0 {
tmp = append(tmp, fmt.Sprintf("timestamp >= FROM_UNIXTIME(%v)", s.StartTime))
filterDB = filterDB.Where("timestamp >= FROM_UNIXTIME(?)", s.StartTime)
}
if s.EndTime != 0 {
tmp = append(tmp, fmt.Sprintf("timestamp <= FROM_UNIXTIME(%v)", s.EndTime))
filterDB = filterDB.Where("timestamp <= FROM_UNIXTIME(?)", s.EndTime)
}
if s.Priority != -1 {
tmp = append(tmp, fmt.Sprintf("priority = %d", s.Priority))
filterDB = filterDB.Where("priority = ?", s.Priority)
}
if s.Status != "" {
status := ""
statusTmp := strings.Split(s.Status, ",")
for indx, n := range statusTmp {
if indx == 0 {
status = fmt.Sprintf(" status = '%s' ", n)
} else {
status = fmt.Sprintf(" %s OR status = '%s' ", status, n)
}
}
status = fmt.Sprintf("( %s )", status)
tmp = append(tmp, status)
filterDB = filterDB.Where("status in (?)", statusTmp)
}
if s.ProcessStatus != "" {
pstatus := ""
pstatusTmp := strings.Split(s.ProcessStatus, ",")
for indx, n := range pstatusTmp {
if indx == 0 {
pstatus = fmt.Sprintf(" process_status = '%s' ", n)
} else {
pstatus = fmt.Sprintf(" %s OR process_status = '%s' ", pstatus, n)
}
}
pstatus = fmt.Sprintf("( %s )", pstatus)
tmp = append(tmp, pstatus)
filterDB = filterDB.Where("process_status in (?)", pstatusTmp)
}
if s.Metrics != "" {
tmp = append(tmp, fmt.Sprintf("metrics regexp '%s'", s.Metrics))
filterDB = filterDB.Where("metric regexp ?", s.Metrics)
}
if s.EventId != "" {
tmp = append(tmp, fmt.Sprintf("id = '%s'", s.EventId))
filterDB = filterDB.Where("id = ?", s.EventId)
}
if s.Endpoints != nil && len(s.Endpoints) != 0 {
for i, ep := range s.Endpoints {
s.Endpoints[i] = fmt.Sprintf("'%s'", ep)
}
tmp = append(tmp, fmt.Sprintf("endpoint in (%s)", strings.Join(s.Endpoints, ", ")))
filterDB = filterDB.Where("endpoint in (?)", s.Endpoints)
}
if s.StrategyId != 0 {
tmp = append(tmp, fmt.Sprintf("strategy_id = %d", s.StrategyId))
filterDB = filterDB.Where("strategy_id = ?", s.StrategyId)
}
if s.TemplateId != 0 {
tmp = append(tmp, fmt.Sprintf("template_id = %d", s.TemplateId))
filterDB = filterDB.Where("template_id = ?", s.TemplateId)
}
filterStrTmp := strings.Join(tmp, " AND ")
if filterStrTmp != "" {
filterStrTmp = fmt.Sprintf("WHERE %s", filterStrTmp)
}
return filterStrTmp
return filterDB
}

func AlarmLists(c *gin.Context) {
var inputs APIGetAlarmListsInputs
//set default
inputs.Page = -1
inputs.Limit = -1
inputs.Priority = -1
if err := c.Bind(&inputs); err != nil {
h.JSONR(c, badstatus, err)
Expand All @@ -126,31 +106,44 @@ func AlarmLists(c *gin.Context) {
h.JSONR(c, badstatus, err)
return
}
filterCollector := inputs.collectFilters()
//for get correct table name
f := alm.EventCases{}
alarmDB := inputs.collectDBFilters(db.Alarm, f.TableName(), nil)
cevens := []alm.EventCases{}
perparedSql := ""
//if no specific, will give return first 2000 records
if inputs.Page == -1 {
if inputs.Limit >= 2000 || inputs.Limit == 0 {
inputs.Limit = 2000
}
perparedSql = fmt.Sprintf("select * from %s %s order by timestamp DESC limit %d", f.TableName(), filterCollector, inputs.Limit)
if inputs.Page == -1 && inputs.Limit == -1{
inputs.Limit = 2000
alarmDB = alarmDB.Order("timestamp DESC").Limit(inputs.Limit)
} else if inputs.Limit == -1 {
// set page but not set limit
h.JSONR(c, badstatus, errors.New("You set page but skip limit params, please check your input"))
return
} else {
// set limit but not set page
if inputs.Page == -1 {
// limit invalid
if inputs.Limit <= 0 {
h.JSONR(c, badstatus, errors.New("limit or page can not set to 0 or less than 0"))
return
}
// set default page
inputs.Page = 1
} else {
// set page and limit
// page or limit invalid
if inputs.Page <= 0 || inputs.Limit <= 0 {
h.JSONR(c, badstatus, errors.New("limit or page can not set to 0 or less than 0"))
return
}
}
//set the max limit of each page
if inputs.Limit >= 50 {
inputs.Limit = 50
}

// if page stands for step page
// {"page":0} for actual page 1
// step = page * limit
step := inputs.Page * inputs.Limit

perparedSql = fmt.Sprintf("select * from %s %s order by timestamp DESC limit %d,%d", f.TableName(), filterCollector, step, inputs.Limit)
step := (inputs.Page -1) * inputs.Limit
alarmDB = alarmDB.Order("timestamp DESC").Offset(step).Limit(inputs.Limit)
}
db.Alarm.Raw(perparedSql).Find(&cevens)
alarmDB.Find(&cevens)
h.JSONR(c, cevens)
}

Expand All @@ -166,26 +159,25 @@ type APIEventsGetInputs struct {
Page int `json:"page" form:"page"`
}

func (s APIEventsGetInputs) collectFilters() string {
tmp := []string{}
filterStrTmp := ""
func (s APIEventsGetInputs) collectDBFilters(database *gorm.DB, tableName string, columns []string) *gorm.DB {
filterDB := database.Table(tableName)
// nil columns mean select all columns
if columns != nil && len(columns) != 0 {
filterDB = filterDB.Select(columns)
}
if s.StartTime != 0 {
tmp = append(tmp, fmt.Sprintf("timestamp >= FROM_UNIXTIME(%v)", s.StartTime))
filterDB = filterDB.Where("timestamp >= FROM_UNIXTIME(?)", s.StartTime)
}
if s.EndTime != 0 {
tmp = append(tmp, fmt.Sprintf("timestamp <= FROM_UNIXTIME(%v)", s.EndTime))
filterDB = filterDB.Where("timestamp <= FROM_UNIXTIME(?)", s.EndTime)
}
if s.EventId != "" {
tmp = append(tmp, fmt.Sprintf("event_caseId = '%s'", s.EventId))
filterDB = filterDB.Where("event_caseId = ?", s.EventId)
}
if s.Status == 0 || s.Status == 1 {
tmp = append(tmp, fmt.Sprintf("status = %d", s.Status))
}
if len(tmp) != 0 {
filterStrTmp = strings.Join(tmp, " AND ")
filterStrTmp = fmt.Sprintf("WHERE %s", filterStrTmp)
filterDB = filterDB.Where("status = ?", s.Status)
}
return filterStrTmp
return filterDB
}

func EventsGet(c *gin.Context) {
Expand All @@ -195,14 +187,14 @@ func EventsGet(c *gin.Context) {
h.JSONR(c, badstatus, err)
return
}
filterCollector := inputs.collectFilters()
//for get correct table name
f := alm.Events{}
eventDB := inputs.collectDBFilters(db.Alarm, f.TableName(), []string{"id", "step", "event_caseId", "cond", "status", "timestamp"})
evens := []alm.Events{}
if inputs.Limit == 0 || inputs.Limit >= 50 {
if inputs.Limit <= 0 || inputs.Limit >= 50 {
inputs.Limit = 50
}
perparedSql := fmt.Sprintf("select id, step, event_caseId, cond, status, timestamp from %s %s order by timestamp DESC limit %d,%d", f.TableName(), filterCollector, inputs.Page, inputs.Limit)
db.Alarm.Raw(perparedSql).Scan(&evens)
step := (inputs.Page -1) * inputs.Limit
eventDB.Order("timestamp DESC").Offset(step).Limit(inputs.Limit).Scan(&evens)
h.JSONR(c, evens)
}
38 changes: 16 additions & 22 deletions modules/api/app/controller/alarm/alarm_notes_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
h "github.com/open-falcon/falcon-plus/modules/api/app/helper"
alm "github.com/open-falcon/falcon-plus/modules/api/app/model/alarm"
"strings"
"time"
)

Expand All @@ -45,25 +45,25 @@ func (input APIGetNotesOfAlarmInputs) checkInputsContain() error {
return nil
}

func (s APIGetNotesOfAlarmInputs) collectFilters() string {
tmp := []string{}
func (s APIGetNotesOfAlarmInputs) collectDBFilters(database *gorm.DB, tableName string, columns []string) *gorm.DB {
filterDB := database.Table(tableName)
// nil columns mean select all columns
if columns != nil && len(columns) != 0 {
filterDB = filterDB.Select(columns)
}
if s.StartTime != 0 {
tmp = append(tmp, fmt.Sprintf("timestamp >= FROM_UNIXTIME(%v)", s.StartTime))
filterDB = filterDB.Where("timestamp >= FROM_UNIXTIME(?)", s.StartTime)
}
if s.EndTime != 0 {
tmp = append(tmp, fmt.Sprintf("timestamp <= FROM_UNIXTIME(%v)", s.EndTime))
filterDB = filterDB.Where("timestamp <= FROM_UNIXTIME(?)", s.EndTime)
}
if s.Status != "" {
tmp = append(tmp, fmt.Sprintf("status = '%s'", s.Status))
filterDB = filterDB.Where("status = ?", s.Status)
}
if s.EventId != "" {
tmp = append(tmp, fmt.Sprintf("event_caseId = '%s'", s.EventId))
}
filterStrTmp := strings.Join(tmp, " AND ")
if filterStrTmp != "" {
filterStrTmp = fmt.Sprintf("WHERE %s", filterStrTmp)
filterDB = filterDB.Where("event_caseId = ?", s.EventId)
}
return filterStrTmp
return filterDB
}

type APIGetNotesOfAlarmOuput struct {
Expand All @@ -85,21 +85,15 @@ func GetNotesOfAlarm(c *gin.Context) {
h.JSONR(c, badstatus, err)
return
}
filterCollector := inputs.collectFilters()
//for get correct table name
f := alm.EventNote{}
noteDB := inputs.collectDBFilters(db.Alarm, f.TableName(), []string{"id", "event_caseId", "note", "case_id", "status", "timestamp", "user_id"})
notes := []alm.EventNote{}
if inputs.Limit == 0 || inputs.Limit >= 50 {
if inputs.Limit <= 0 || inputs.Limit >= 50 {
inputs.Limit = 50
}
perparedSql := fmt.Sprintf(
"select id, event_caseId, note, case_id, status, timestamp, user_id from %s %s order by timestamp DESC limit %d,%d",
f.TableName(),
filterCollector,
inputs.Page,
inputs.Limit,
)
db.Alarm.Raw(perparedSql).Scan(&notes)
step := (inputs.Page - 1) * inputs.Limit
noteDB.Order("timestamp DESC").Offset(step).Limit(inputs.Limit).Scan(&notes)
output := []APIGetNotesOfAlarmOuput{}
for _, n := range notes {
output = append(output, APIGetNotesOfAlarmOuput{
Expand Down
12 changes: 6 additions & 6 deletions modules/api/app/controller/expression/expression_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func GetExpressionList(c *gin.Context) {
var dt *gorm.DB
expressions := []f.Expression{}
if limit != -1 && page != -1 {
dt = db.Falcon.Raw(fmt.Sprintf("SELECT * from expression limit %d,%d", page, limit)).Scan(&expressions)
dt = db.Falcon.Raw("SELECT * from expression limit ?,?", page, limit).Scan(&expressions)
} else {
dt = db.Falcon.Find(&expressions)
}
Expand All @@ -66,8 +66,8 @@ func GetExpression(c *gin.Context) {
h.JSONR(c, badstatus, err)
return
}
expression := f.Expression{ID: int64(eid)}
if dt := db.Falcon.Find(&expression); dt.Error != nil {
expression := f.Expression{}
if dt := db.Falcon.Where("id = ?", eid).Find(&expression); dt.Error != nil {
h.JSONR(c, badstatus, dt.Error)
return
}
Expand Down Expand Up @@ -287,9 +287,9 @@ func DeleteExpression(c *gin.Context) {
}
tx := db.Falcon.Begin()
user, _ := h.GetUser(c)
expression := f.Expression{ID: int64(eid)}
expression := f.Expression{}
if !user.IsAdmin() {
tx.Find(&expression)
tx.Where("id = ?", eid).Find(&expression)
if expression.CreateUser != user.Name {
h.JSONR(c, badstatus, "You don't have permission!")
tx.Rollback()
Expand All @@ -302,7 +302,7 @@ func DeleteExpression(c *gin.Context) {
tx.Rollback()
return
}
if dt := tx.Delete(&expression); dt.Error != nil {
if dt := tx.Where("id = ?", eid).Delete(&expression); dt.Error != nil {
h.JSONR(c, badstatus, dt.Error)
tx.Rollback()
return
Expand Down
9 changes: 3 additions & 6 deletions modules/api/app/controller/graph/grafana_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package graph

import (
"fmt"
"regexp"
"strings"

Expand All @@ -25,7 +24,6 @@ import (
cmodel "github.com/open-falcon/falcon-plus/common/model"
h "github.com/open-falcon/falcon-plus/modules/api/app/helper"
m "github.com/open-falcon/falcon-plus/modules/api/app/model/graph"
u "github.com/open-falcon/falcon-plus/modules/api/app/utils"
)

type APIGrafanaMainQueryInputs struct {
Expand Down Expand Up @@ -182,11 +180,10 @@ func responseCounterRegexp(regexpKey string) (result []APIGrafanaMainQueryOutput
if len(hostIds) == 0 {
return
}
idConcact, _ := u.ArrInt64ToString(hostIds)
//for get right table name
countHelp := m.EndpointCounter{}
counters := []m.EndpointCounter{}
db.Graph.Table(countHelp.TableName()).Where(fmt.Sprintf("endpoint_id IN (%s) AND counter regexp '%s'", idConcact, counter)).Scan(&counters)
db.Graph.Table(countHelp.TableName()).Where("endpoint_id IN (?)", hostIds).Where("counter regexp ?", counter).Scan(&counters)
//if not any counter matched
if len(counters) == 0 {
return
Expand Down Expand Up @@ -255,9 +252,9 @@ func GrafanaRender(c *gin.Context) {
counters := []m.EndpointCounter{}
hostIds := findEndpointIdByEndpointList(hosts)
if flag {
db.Graph.Table(ecHelp.TableName()).Select("distinct counter").Where(fmt.Sprintf("endpoint_id IN (%s) AND counter = '%s'", u.ArrInt64ToStringMust(hostIds), counter)).Scan(&counters)
db.Graph.Table(ecHelp.TableName()).Select("distinct counter").Where("endpoint_id IN (?)", hostIds).Where("counter = ?", counter).Scan(&counters)
} else {
db.Graph.Table(ecHelp.TableName()).Select("distinct counter").Where(fmt.Sprintf("endpoint_id IN (%s) AND counter regexp '%s'", u.ArrInt64ToStringMust(hostIds), counter)).Scan(&counters)
db.Graph.Table(ecHelp.TableName()).Select("distinct counter").Where("endpoint_id IN (?)", hostIds).Where("counter regexp ?", counter).Scan(&counters)
}
if len(counters) == 0 {
// 没有匹配到的继续执行,避免当grafana graph有多个查询时,其他正常的查询也无法渲染视图
Expand Down
5 changes: 3 additions & 2 deletions modules/api/app/controller/graph/graph_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ func EndpointCounterRegexpQuery(c *gin.Context) {
if page > 1 {
offset = (page - 1) * limit
}
eidArray := []string{}
if eid == "" {
h.JSONR(c, http.StatusBadRequest, "eid is missing")
} else {
Expand All @@ -179,11 +180,11 @@ func EndpointCounterRegexpQuery(c *gin.Context) {
h.JSONR(c, http.StatusBadRequest, "input error, please check your input info.")
return
} else {
eids = fmt.Sprintf("(%s)", eids)
eidArray = strings.Split(eids, ",")
}

var counters []m.EndpointCounter
dt := db.Graph.Table("endpoint_counter").Select("endpoint_id, counter, step, type").Where(fmt.Sprintf("endpoint_id IN %s", eids))
dt := db.Graph.Table("endpoint_counter").Select("endpoint_id, counter, step, type").Where("endpoint_id IN (?)", eidArray)
if metricQuery != "" {
qs := strings.Split(metricQuery, " ")
if len(qs) > 0 {
Expand Down
Loading