Skip to content

Commit

Permalink
Merge pull request #263 from alisaifee/skip-init-on-disabled
Browse files Browse the repository at this point in the history
Skip initialization sequence when disabled
  • Loading branch information
alisaifee authored Aug 14, 2020
2 parents 23edbde + 601a665 commit d424229
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 6 deletions.
7 changes: 6 additions & 1 deletion flask_limiter/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(
self.logger = logging.getLogger("flask-limiter")

self.enabled = enabled
self.initialized = False
self._default_limits = []
self._default_limits_per_method = default_limits_per_method
self._default_limits_exempt_when = default_limits_exempt_when
Expand Down Expand Up @@ -206,6 +207,9 @@ def init_app(self, app):
"""
config = app.config
self.enabled = config.setdefault(C.ENABLED, self.enabled)
if not self.enabled:
return

self._default_limits_per_method = config.setdefault(
C.DEFAULT_LIMITS_PER_METHOD, self._default_limits_per_method
)
Expand Down Expand Up @@ -307,6 +311,7 @@ def init_app(self, app):
app.after_request(self.__inject_headers)

app.extensions['limiter'] = self
self.initialized = True

def __configure_fallbacks(self, app, strategy):
config = app.config
Expand Down Expand Up @@ -496,7 +501,7 @@ def __check_request_limit(self, in_middleware=True):
)
if (
not request.endpoint
or not self.enabled
or not (self.enabled and self.initialized)
or view_func == current_app.send_static_file
or name in self._exempt_routes
or request.blueprint in self._blueprint_exempt
Expand Down
46 changes: 41 additions & 5 deletions tests/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,51 @@ def test_invalid_storage_string():
def test_constructor_arguments_over_config(redis_connection):
app = Flask(__name__)
app.config.setdefault(C.STRATEGY, "fixed-window-elastic-expiry")
limiter = Limiter(
strategy='moving-window', key_func=get_remote_address
)
limiter = Limiter(strategy='moving-window', key_func=get_remote_address)
limiter.init_app(app)
app.config.setdefault(C.STORAGE_URL, "redis://localhost:36379")
assert type(limiter._limiter) == MovingWindowRateLimiter
limiter = Limiter(
storage_uri='memcached://localhost:31211',
key_func=get_remote_address
storage_uri='memcached://localhost:31211', key_func=get_remote_address
)
limiter.init_app(app)
assert type(limiter._storage) == MemcachedStorage


def test_invalid_config_with_disabled():
app = Flask(__name__)
app.config.setdefault(C.ENABLED, False)
app.config.setdefault(C.STORAGE_URL, "fubar://")

limiter = Limiter(app, default_limits=["1/hour"])

@app.route("/")
def root():
return "root"

@app.route("/explicit")
@limiter.limit("2/hour")
def explicit():
return "explicit"

with app.test_client() as client:
assert client.get("/").status_code == 200
assert client.get("/").status_code == 200
assert client.get("/explicit").status_code == 200
assert client.get("/explicit").status_code == 200
assert client.get("/explicit").status_code == 200


def test_uninitialized_limiter():
app = Flask(__name__)
limiter = Limiter(default_limits=["1/hour"])

@app.route("/")
@limiter.limit("2/hour")
def root():
return "root"

with app.test_client() as client:
assert client.get("/").status_code == 200
assert client.get("/").status_code == 200
assert client.get("/").status_code == 200

0 comments on commit d424229

Please sign in to comment.