Skip to content

Commit

Permalink
rpcsrv: carefully store Oracle service
Browse files Browse the repository at this point in the history
And simplify atomic service value stored by RPC server. Oracle service can
either be an untyped nil or be the proper non-nil *oracle.Oracle.
Otherwise `submitoracleresponse` RPC handler doesn't work properly.

Signed-off-by: Anna Shaleva <shaleva.ann@nspcc.ru>
  • Loading branch information
AnnaShaleva committed Aug 9, 2023
1 parent 6c1240d commit 8ae8630
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 12 deletions.
9 changes: 8 additions & 1 deletion cli/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,14 @@ func resetDB(ctx *cli.Context) error {
return nil
}

func mkOracle(config config.OracleConfiguration, magic netmode.Magic, chain *core.Blockchain, serv *network.Server, log *zap.Logger) (*oracle.Oracle, error) {
// OracleService is an interface representing Oracle service with network.Service
// capabilities and ability to submit oracle responses.
type OracleService interface {
rpcsrv.OracleHandler
network.Service
}

func mkOracle(config config.OracleConfiguration, magic netmode.Magic, chain *core.Blockchain, serv *network.Server, log *zap.Logger) (OracleService, error) {
if !config.Enabled {
return nil, nil
}
Expand Down
14 changes: 8 additions & 6 deletions pkg/services/rpcsrv/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ var invalidBlockHeightError = func(index int, height int) *neorpc.Error {
return neorpc.NewRPCError("Invalid block height", fmt.Sprintf("param at index %d should be greater than or equal to 0 and less then or equal to current block height, got: %d", index, height))
}

// New creates a new Server struct.
// New creates a new Server struct. Pay attention that orc is expected to be either
// untyped nil or non-nil structure implementing OracleHandler interface.
func New(chain Ledger, conf config.RPC, coreServer *network.Server,
orc OracleHandler, log *zap.Logger, errChan chan<- error) Server {
addrs := conf.GetAddresses()
Expand Down Expand Up @@ -293,7 +294,7 @@ func New(chain Ledger, conf config.RPC, coreServer *network.Server,
}
var oracleWrapped = new(atomic.Value)
if orc != nil {
oracleWrapped.Store(&orc)
oracleWrapped.Store(orc)
}
var wsOriginChecker func(*http.Request) bool
if conf.EnableCORSWorkaround {
Expand Down Expand Up @@ -445,7 +446,7 @@ func (s *Server) Shutdown() {

// SetOracleHandler allows to update oracle handler used by the Server.
func (s *Server) SetOracleHandler(orc OracleHandler) {
s.oracle.Store(&orc)
s.oracle.Store(orc)
}

func (s *Server) handleHTTPRequest(w http.ResponseWriter, httpRequest *http.Request) {
Expand Down Expand Up @@ -2461,10 +2462,11 @@ func getRelayResult(err error, hash util.Uint256) (any, *neorpc.Error) {
}

func (s *Server) submitOracleResponse(ps params.Params) (any, *neorpc.Error) {
oracle := s.oracle.Load().(*OracleHandler)
if oracle == nil || *oracle == nil {
oraclePtr := s.oracle.Load()
if oraclePtr == nil {
return nil, neorpc.NewRPCError("Oracle is not enabled", "")
}
oracle := oraclePtr.(OracleHandler)
var pub *keys.PublicKey
pubBytes, err := ps.Value(0).GetBytesBase64()
if err == nil {
Expand All @@ -2489,7 +2491,7 @@ func (s *Server) submitOracleResponse(ps params.Params) (any, *neorpc.Error) {
if !pub.Verify(msgSig, hash.Sha256(data).BytesBE()) {
return nil, neorpc.NewRPCError("Invalid request signature", "")
}
(*oracle).AddResponse(pub, uint64(reqID), txSig)
(oracle).AddResponse(pub, uint64(reqID), txSig)
return json.RawMessage([]byte("{}")), nil
}

Expand Down
10 changes: 5 additions & 5 deletions pkg/services/rpcsrv/server_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ const (
notaryPass = "one"
)

func getUnitTestChain(t testing.TB, enableOracle bool, enableNotary bool, disableIteratorSessions bool) (*core.Blockchain, *oracle.Oracle, config.Config, *zap.Logger) {
func getUnitTestChain(t testing.TB, enableOracle bool, enableNotary bool, disableIteratorSessions bool) (*core.Blockchain, OracleHandler, config.Config, *zap.Logger) {
return getUnitTestChainWithCustomConfig(t, enableOracle, enableNotary, func(cfg *config.Config) {
if disableIteratorSessions {
cfg.ApplicationConfiguration.RPC.SessionEnabled = false
Expand All @@ -56,7 +56,7 @@ func getUnitTestChain(t testing.TB, enableOracle bool, enableNotary bool, disabl
}
})
}
func getUnitTestChainWithCustomConfig(t testing.TB, enableOracle bool, enableNotary bool, customCfg func(configuration *config.Config)) (*core.Blockchain, *oracle.Oracle, config.Config, *zap.Logger) {
func getUnitTestChainWithCustomConfig(t testing.TB, enableOracle bool, enableNotary bool, customCfg func(configuration *config.Config)) (*core.Blockchain, OracleHandler, config.Config, *zap.Logger) {
net := netmode.UnitTestNet
configPath := "../../../config"
cfg, err := config.Load(configPath, net)
Expand All @@ -70,7 +70,7 @@ func getUnitTestChainWithCustomConfig(t testing.TB, enableOracle bool, enableNot
chain, err := core.NewBlockchain(memoryStore, cfg.Blockchain(), logger)
require.NoError(t, err, "could not create chain")

var orc *oracle.Oracle
var orc OracleHandler
if enableOracle {
orc, err = oracle.NewOracle(oracle.Config{
Log: logger,
Expand All @@ -79,7 +79,7 @@ func getUnitTestChainWithCustomConfig(t testing.TB, enableOracle bool, enableNot
Chain: chain,
})
require.NoError(t, err)
chain.SetOracle(orc)
chain.SetOracle(orc.(*oracle.Oracle))
}

go chain.Run()
Expand Down Expand Up @@ -115,7 +115,7 @@ func initClearServerWithServices(t testing.TB, needOracle bool, needNotary bool,
return wrapUnitTestChain(t, chain, orc, cfg, logger)
}

func wrapUnitTestChain(t testing.TB, chain *core.Blockchain, orc *oracle.Oracle, cfg config.Config, logger *zap.Logger) (*core.Blockchain, *Server, *httptest.Server) {
func wrapUnitTestChain(t testing.TB, chain *core.Blockchain, orc OracleHandler, cfg config.Config, logger *zap.Logger) (*core.Blockchain, *Server, *httptest.Server) {
serverConfig, err := network.NewServerConfig(cfg)
require.NoError(t, err)
serverConfig.UserAgent = fmt.Sprintf(config.UserAgentFormat, "0.98.6-test")
Expand Down

0 comments on commit 8ae8630

Please sign in to comment.