diff --git a/apps/nsq_to_nsq/nsq_to_nsq.go b/apps/nsq_to_nsq/nsq_to_nsq.go
index 4c2b1c149..b7a46f69e 100644
--- a/apps/nsq_to_nsq/nsq_to_nsq.go
+++ b/apps/nsq_to_nsq/nsq_to_nsq.go
@@ -4,17 +4,20 @@
 package main
 
 import (
+	"encoding/json"
 	"errors"
 	"flag"
 	"fmt"
 	"github.com/bitly/go-hostpool"
 	"github.com/bitly/go-nsq"
+	"github.com/bitly/go-simplejson"
 	"github.com/bitly/nsq/util"
 	"log"
 	"math"
 	"os"
 	"os/signal"
 	"sort"
+	"strconv"
 	"strings"
 	"syscall"
 	"time"
@@ -29,17 +32,21 @@ var (
 	showVersion = flag.Bool("version", false, "print version string")
 
 	topic       = flag.String("topic", "", "nsq topic")
-	channel     = flag.String("channel", "nsq_to_http", "nsq channel")
+	channel     = flag.String("channel", "nsq_to_nsq", "nsq channel")
 	destTopic   = flag.String("destination-topic", "", "destination nsq topic")
 	maxInFlight = flag.Int("max-in-flight", 200, "max number of messages to allow in flight")
 
 	statusEvery = flag.Int("status-every", 250, "the # of requests between logging status (per destination), 0 disables")
 	mode        = flag.String("mode", "round-robin", "the upstream request mode options: round-robin (default), hostpool")
 
-	readerOpts       = util.StringArray{}
-	nsqdTCPAddrs     = util.StringArray{}
-	lookupdHTTPAddrs = util.StringArray{}
-	destNsqdTCPAddrs = util.StringArray{}
+	readerOpts          = util.StringArray{}
+	nsqdTCPAddrs        = util.StringArray{}
+	lookupdHTTPAddrs    = util.StringArray{}
+	destNsqdTCPAddrs    = util.StringArray{}
+	whitelistJsonFields = util.StringArray{}
+
+	requireJsonField = flag.String("require-json-field", "", "for JSON messages: only pass messages that contain this field")
+	requireJsonValue = flag.String("require-json-value", "", "for JSON messages: only pass messages in which the required field has this value")
 
 	// TODO: remove, deprecated
 	maxBackoffDuration = flag.Duration("max-backoff-duration", 120*time.Second, "(deprecated) use --reader-opt=max_backoff_duration=X, the maximum backoff duration")
@@ -51,6 +58,8 @@ func init() {
 	flag.Var(&nsqdTCPAddrs, "nsqd-tcp-address", "nsqd TCP address (may be given multiple times)")
 	flag.Var(&destNsqdTCPAddrs, "destination-nsqd-tcp-address", "destination nsqd TCP address (may be given multiple times)")
 	flag.Var(&lookupdHTTPAddrs, "lookupd-http-address", "lookupd HTTP address (may be given multiple times)")
+
+	flag.Var(&whitelistJsonFields, "whitelist-json-field", "for JSON messages: pass this field (may be given multiple times)")
 }
 
 type Durations []time.Duration
@@ -67,6 +76,12 @@ func (s Durations) Less(i, j int) bool {
 	return s[i] < s[j]
 }
 
+// getRequeueDelay returns the requeue delay for m in milliseconds: 60 seconds
+// multiplied by m.Attempts.
+func getRequeueDelay(m *nsq.Message) int {
+	return int(60 * time.Second * time.Duration(m.Attempts) / time.Millisecond)
+}
+
 type PublishHandler struct {
 	addresses util.StringArray
 	writers   map[string]*nsq.Writer
@@ -76,6 +91,10 @@ type PublishHandler struct {
 	reqs      Durations
 	id        int
 	respChan  chan *nsq.WriterTransaction
+
+	requireJsonValueParsed   bool
+	requireJsonValueIsNumber bool
+	requireJsonNumber        float64
 }
 
 func (ph *PublishHandler) responder() {
@@ -108,8 +127,7 @@ func (ph *PublishHandler) responder() {
 			}
 		}
 
-		requeueDelay := int(60 * time.Second * time.Duration(msg.Attempts) / time.Millisecond)
-		respChan <- &nsq.FinishedMessage{msg.Id, requeueDelay, success}
+		respChan <- &nsq.FinishedMessage{msg.Id, getRequeueDelay(msg), success}
 
 		if *statusEvery > 0 {
 			duration := time.Now().Sub(startTime)
@@ -135,8 +153,128 @@ func (ph *PublishHandler) responder() {
 	}
 }
 
+// shouldPassMessage applies the --require-json-field / --require-json-value
+// filter to a decoded message. It returns (pass, backoff): pass reports whether
+// the message should be forwarded; backoff is true only when a value is
+// required but the field is missing, which HandleMessage reports as a failure
+// with a requeue delay. With no required field configured every message passes.
+func (ph *PublishHandler) shouldPassMessage(jsonMsg *simplejson.Json) (bool, bool) {
+	pass := true
+	backoff := false
+
+	if *requireJsonField == "" {
+		return pass, backoff
+	}
+
+	if *requireJsonValue != "" && !ph.requireJsonValueParsed {
+		// cache conversion in case needed while filtering json
+		var err error
+		ph.requireJsonNumber, err = strconv.ParseFloat(*requireJsonValue, 64)
+		ph.requireJsonValueIsNumber = (err == nil)
+		ph.requireJsonValueParsed = true
+	}
+
+	jsonVal, ok := jsonMsg.CheckGet(*requireJsonField)
+	if !ok {
+		pass = false
+		if *requireJsonValue != "" {
+			log.Printf("ERROR: missing field to check required value")
+			backoff = true
+		}
+	} else if *requireJsonValue != "" {
+		// if command-line argument can't convert to float, then it can't match a number
+		// if it can, also integers (up to 2^53 or so) can be compared as float64
+		if strVal, err := jsonVal.String(); err == nil {
+			if strVal != *requireJsonValue {
+				pass = false
+			}
+		} else if ph.requireJsonValueIsNumber {
+			floatVal, err := jsonVal.Float64()
+			if err != nil || ph.requireJsonNumber != floatVal {
+				pass = false
+			}
+		} else {
+			// json value wasn't a plain string, and argument wasn't a number
+			// give up on comparisons of other types
+			pass = false
+		}
+	}
+
+	return pass, backoff
+}
+
+// filterMessage reduces a JSON object to the keys named by
+// --whitelist-json-field and returns the re-encoded body. With no whitelist
+// configured rawMsg is returned unchanged. Whitelisted keys absent from the
+// message are skipped, and whole-number float64 values are stored as int64
+// before marshalling. It fails if the message is not a JSON object or cannot
+// be re-encoded.
+func filterMessage(jsonMsg *simplejson.Json, rawMsg []byte) ([]byte, error) {
+	if len(whitelistJsonFields) == 0 {
+		// no change
+		return rawMsg, nil
+	}
+
+	msg, err := jsonMsg.Map()
+	if err != nil {
+		return nil, errors.New("json is not an object")
+	}
+
+	newMsg := make(map[string]interface{}, len(whitelistJsonFields))
+
+	for _, key := range whitelistJsonFields {
+		value, ok := msg[key]
+		if ok {
+			// avoid printing int as float (go 1.0)
+			switch tvalue := value.(type) {
+			case float64:
+				ivalue := int64(tvalue)
+				if float64(ivalue) == tvalue {
+					newMsg[key] = ivalue
+				} else {
+					newMsg[key] = tvalue
+				}
+			default:
+				newMsg[key] = value
+			}
+		}
+	}
+
+	newRawMsg, err := json.Marshal(newMsg)
+	if err != nil {
+		return nil, fmt.Errorf("unable to marshal filtered message %v: %s", newMsg, err)
+	}
+	return newRawMsg, nil
+}
+
 func (ph *PublishHandler) HandleMessage(m *nsq.Message, respChan chan *nsq.FinishedMessage) {
 	var err error
+	msgBody := m.Body
+
+	if *requireJsonField != "" || len(whitelistJsonFields) > 0 {
+		var jsonMsg *simplejson.Json
+		jsonMsg, err = simplejson.NewJson(m.Body)
+		if err != nil {
+			log.Printf("ERROR: Unable to decode json: %s", m.Body)
+			respChan <- &nsq.FinishedMessage{m.Id, 0, true}
+			return
+		}
+
+		if pass, backoff := ph.shouldPassMessage(jsonMsg); !pass {
+			if backoff {
+				respChan <- &nsq.FinishedMessage{m.Id, getRequeueDelay(m), false}
+			} else {
+				respChan <- &nsq.FinishedMessage{m.Id, 0, true}
+			}
+			return
+		}
+
+		msgBody, err = filterMessage(jsonMsg, m.Body)
+		if err != nil {
+			log.Printf("ERROR: filterMessage() failed: %s", err)
+			respChan <- &nsq.FinishedMessage{m.Id, getRequeueDelay(m), false}
+			return
+		}
+	}
 
 	startTime := time.Now()
 
@@ -144,20 +282,19 @@ func (ph *PublishHandler) HandleMessage(m *nsq.Message, respChan chan *nsq.Finis
 	case ModeRoundRobin:
 		idx := ph.counter % uint64(len(ph.addresses))
 		writer := ph.writers[ph.addresses[idx]]
-		err = writer.PublishAsync(*destTopic, m.Body, ph.respChan, m, respChan, startTime)
+		err = writer.PublishAsync(*destTopic, msgBody, ph.respChan, m, respChan, startTime)
 		ph.counter++
 	case ModeHostPool:
 		hostPoolResponse := ph.hostPool.Get()
 		writer := ph.writers[hostPoolResponse.Host()]
-		err = writer.PublishAsync(*destTopic, m.Body, ph.respChan, m, respChan, startTime, hostPoolResponse)
+		err = writer.PublishAsync(*destTopic, msgBody, ph.respChan, m, respChan, startTime, hostPoolResponse)
 		if err != nil {
 			hostPoolResponse.Mark(err)
 		}
 	}
 
 	if err != nil {
-		requeueDelay := int(60 * time.Second * time.Duration(m.Attempts) / time.Millisecond)
-		respChan <- &nsq.FinishedMessage{m.Id, requeueDelay, false}
+		respChan <- &nsq.FinishedMessage{m.Id, getRequeueDelay(m), false}
 	}
 }
 