diff --git a/cmd/start.go b/cmd/start.go index 8f4e6ca..1f796a7 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -37,6 +37,10 @@ var ( keyFile = viper.GetString("key_file") commonName = viper.GetString("common_name") + corsAllowedMethods = viper.GetStringSlice("cors_allowed_methods") + corsAllowedOrigins = viper.GetStringSlice("cors_allowed_origins") + corsAllowedHeaders = viper.GetStringSlice("cors_allowed_headers") + logLevel = viper.GetString("log_level") logFile = viper.GetString("log_file") logMaxSize = viper.GetInt("log_max_size") @@ -73,7 +77,7 @@ var ( return err } - grpcGateway, err := server.NewGRPCGateway(httpAddress, grpcAddress, certificateFile, keyFile, commonName, logger) + grpcGateway, err := server.NewGRPCGateway(httpAddress, grpcAddress, certificateFile, keyFile, commonName, corsAllowedMethods, corsAllowedOrigins, corsAllowedHeaders, logger) if err != nil { return err } @@ -186,6 +190,9 @@ func init() { startCmd.PersistentFlags().StringVar(&certificateFile, "certificate-file", "", "path to the client server TLS certificate file") startCmd.PersistentFlags().StringVar(&keyFile, "key-file", "", "path to the client server TLS key file") startCmd.PersistentFlags().StringVar(&commonName, "common-name", "", "certificate common name") + startCmd.PersistentFlags().StringSliceVar(&corsAllowedMethods, "cors-allowed-methods", []string{}, "CORS allowed methods (ex: GET,PUT,DELETE,POST)") + startCmd.PersistentFlags().StringSliceVar(&corsAllowedOrigins, "cors-allowed-origins", []string{}, "CORS allowed origins (ex: http://localhost:8080,http://localhost:80)") + startCmd.PersistentFlags().StringSliceVar(&corsAllowedHeaders, "cors-allowed-headers", []string{}, "CORS allowed headers (ex: content-type,x-some-key)") startCmd.PersistentFlags().StringVar(&logLevel, "log-level", "INFO", "log level") startCmd.PersistentFlags().StringVar(&logFile, "log-file", os.Stderr.Name(), "log file") startCmd.PersistentFlags().IntVar(&logMaxSize, "log-max-size", 500, "max size of a log file in megabytes") @@ -203,6 +210,9 @@ func init() { _ = viper.BindPFlag("certificate_file", startCmd.PersistentFlags().Lookup("certificate-file")) _ = viper.BindPFlag("key_file", startCmd.PersistentFlags().Lookup("key-file")) _ = viper.BindPFlag("common_name", startCmd.PersistentFlags().Lookup("common-name")) + _ = viper.BindPFlag("cors_allowed_methods", startCmd.PersistentFlags().Lookup("cors-allowed-methods")) + _ = viper.BindPFlag("cors_allowed_origins", startCmd.PersistentFlags().Lookup("cors-allowed-origins")) + _ = viper.BindPFlag("cors_allowed_headers", startCmd.PersistentFlags().Lookup("cors-allowed-headers")) _ = viper.BindPFlag("log_level", startCmd.PersistentFlags().Lookup("log-level")) _ = viper.BindPFlag("log_max_size", startCmd.PersistentFlags().Lookup("log-max-size")) _ = viper.BindPFlag("log_max_backups", startCmd.PersistentFlags().Lookup("log-max-backups")) diff --git a/cmd/variables.go b/cmd/variables.go index 8022742..0e0ef9b 100644 --- a/cmd/variables.go +++ b/cmd/variables.go @@ -1,22 +1,25 @@ package cmd var ( - configFile string - id string - raftAddress string - grpcAddress string - httpAddress string - dataDirectory string - peerGrpcAddress string - mappingFile string - certificateFile string - keyFile string - commonName string - file string - logLevel string - logFile string - logMaxSize int - logMaxBackups int - logMaxAge int - logCompress bool + configFile string + id string + raftAddress string + grpcAddress string + httpAddress string + dataDirectory string + peerGrpcAddress string + mappingFile string + certificateFile string + keyFile string + commonName string + corsAllowedMethods []string + corsAllowedOrigins []string + corsAllowedHeaders []string + file string + logLevel string + logFile string + logMaxSize int + logMaxBackups int + logMaxAge int + logCompress bool ) diff --git a/etc/blast.yaml b/etc/blast.yaml index ab362c1..a03e89a 100644 --- a/etc/blast.yaml +++ b/etc/blast.yaml @@ -1,3 +1,6 @@ +# +# General +# id: "node1" raft_address: ":7000" grpc_address: ":9000" @@ -5,9 +8,33 @@ http_address: ":8000" data_directory: "/tmp/blast/node1/data" #mapping_file: "./etc/blast_mapping.json" peer_grpc_address: "" + +# +# TLS +# #certificate_file: "./etc/blast-cert.pem" #key_file: "./etc/blast-key.pem" #common_name: "localhost" + +# +# CORS +# +#cors_allowed_methods: [ +# "GET", +# "PUT", +# "DELETE", +# "POST" +#] +#cors_allowed_origins: [ +# "http://localhost:8080" +#] +#cors_allowed_headers: [ +# "content-type" +#] + +# +# Logging +# log_level: "INFO" log_file: "" #log_max_size: 500 diff --git a/go.mod b/go.mod index a218f2c..0a3a8dd 100644 --- a/go.mod +++ b/go.mod @@ -23,8 +23,9 @@ require ( github.com/facebookgo/ensure v0.0.0-20160127193407-b4ab57deab51 // indirect github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870 // indirect - github.com/gogo/protobuf v1.3.0 + github.com/gogo/protobuf v1.3.0 // indirect github.com/golang/protobuf v1.3.5 + github.com/gorilla/handlers v1.4.2 github.com/grpc-ecosystem/go-grpc-middleware v1.2.0 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/grpc-ecosystem/grpc-gateway v1.14.3 diff --git a/go.sum b/go.sum index 974344e..c714a35 100644 --- a/go.sum +++ b/go.sum @@ -138,6 +138,8 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/handlers v1.4.2 h1:0QniY0USkHQ1RGCLfKxeNHK9bkDHGRYGNDFBCS+YARg= +github.com/gorilla/handlers v1.4.2/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-middleware v1.2.0 h1:0IKlLyQ3Hs9nDaiK5cSHAGmcQEIC8l2Ts1u6x5Dfrqg= diff --git a/server/grpc_gateway.go b/server/grpc_gateway.go index c319fc0..c63572c 100644 --- a/server/grpc_gateway.go +++ b/server/grpc_gateway.go @@ -8,6 +8,7 @@ import ( "time" "github.com/golang/protobuf/proto" + "github.com/gorilla/handlers" "github.com/grpc-ecosystem/grpc-gateway/runtime" "github.com/mosuka/blast/marshaler" "github.com/mosuka/blast/protobuf" @@ -41,10 +42,14 @@ type GRPCGateway struct { certificateFile string keyFile string + corsAllowedMethods []string + corsAllowedOrigins []string + corsAllowedHeaders []string + logger *zap.Logger } -func NewGRPCGateway(httpAddress string, grpcAddress string, certificateFile string, keyFile string, commonName string, logger *zap.Logger) (*GRPCGateway, error) { +func NewGRPCGateway(httpAddress string, grpcAddress string, certificateFile string, keyFile string, commonName string, corsAllowedMethods []string, corsAllowedOrigins []string, corsAllowedHeaders []string, logger *zap.Logger) (*GRPCGateway, error) { dialOpts := []grpc.DialOption{ grpc.WithDefaultCallOptions( grpc.MaxCallSendMsgSize(math.MaxInt64), @@ -90,25 +95,52 @@ func NewGRPCGateway(httpAddress string, grpcAddress string, certificateFile stri } return &GRPCGateway{ - httpAddress: httpAddress, - grpcAddress: grpcAddress, - listener: listener, - mux: mux, - cancel: cancel, - certificateFile: certificateFile, - keyFile: keyFile, - logger: logger, + httpAddress: httpAddress, + grpcAddress: grpcAddress, + listener: listener, + mux: mux, + cancel: cancel, + certificateFile: certificateFile, + keyFile: keyFile, + corsAllowedMethods: corsAllowedMethods, + corsAllowedOrigins: corsAllowedOrigins, + corsAllowedHeaders: corsAllowedHeaders, + logger: logger, }, nil } func (s *GRPCGateway) Start() error { + corsOpts := make([]handlers.CORSOption, 0) + + if s.corsAllowedMethods != nil && len(s.corsAllowedMethods) > 0 { + corsOpts = append(corsOpts, handlers.AllowedMethods(s.corsAllowedMethods)) + } + if s.corsAllowedOrigins != nil && len(s.corsAllowedOrigins) > 0 { + corsOpts = append(corsOpts, handlers.AllowedMethods(s.corsAllowedOrigins)) + } + if s.corsAllowedHeaders != nil && len(s.corsAllowedHeaders) > 0 { + corsOpts = append(corsOpts, handlers.AllowedMethods(s.corsAllowedHeaders)) + } + + corsMux := handlers.CORS( + corsOpts..., + )(s.mux) + if s.certificateFile == "" && s.keyFile == "" { go func() { - _ = http.Serve(s.listener, s.mux) + if len(corsOpts) > 0 { + _ = http.Serve(s.listener, corsMux) + } else { + _ = http.Serve(s.listener, s.mux) + } }() } else { go func() { - _ = http.ServeTLS(s.listener, s.mux, s.certificateFile, s.keyFile) + if len(corsOpts) > 0 { + _ = http.ServeTLS(s.listener, corsMux, s.certificateFile, s.keyFile) + } else { + _ = http.ServeTLS(s.listener, s.mux, s.certificateFile, s.keyFile) + } }() }