diff --git a/packages/react-router-dom/modules/Link.js b/packages/react-router-dom/modules/Link.js
index a99b869bbd..03614362d7 100644
--- a/packages/react-router-dom/modules/Link.js
+++ b/packages/react-router-dom/modules/Link.js
@@ -4,71 +4,89 @@ import PropTypes from "prop-types";
import invariant from "tiny-invariant";
import { resolveToLocation, normalizeToLocation } from "./utils/locationUtils";
+// React 15 compat
+let { forwardRef } = React;
+if (typeof forwardRef === "undefined") {
+ forwardRef = C => C;
+}
+
function isModifiedEvent(event) {
return !!(event.metaKey || event.altKey || event.ctrlKey || event.shiftKey);
}
-function LinkAnchor({ innerRef, navigate, onClick, ...rest }) {
- const { target } = rest;
-
- return (
- {
- try {
- if (onClick) onClick(event);
- } catch (ex) {
- event.preventDefault();
- throw ex;
- }
-
- if (
- !event.defaultPrevented && // onClick prevented default
- event.button === 0 && // ignore everything but left clicks
- (!target || target === "_self") && // let browser handle "target=_blank" etc.
- !isModifiedEvent(event) // ignore clicks with modifier keys
- ) {
- event.preventDefault();
- navigate();
- }
- }}
- />
- );
+const LinkAnchor = forwardRef(
+ ({ innerRef, navigate, onClick, ...rest }, forwardedRef) => {
+ const { target } = rest;
+
+ return (
+ {
+ try {
+ if (onClick) onClick(event);
+ } catch (ex) {
+ event.preventDefault();
+ throw ex;
+ }
+
+ if (
+ !event.defaultPrevented && // onClick prevented default
+ event.button === 0 && // ignore everything but left clicks
+ (!target || target === "_self") && // let browser handle "target=_blank" etc.
+ !isModifiedEvent(event) // ignore clicks with modifier keys
+ ) {
+ event.preventDefault();
+ navigate();
+ }
+ }}
+ />
+ );
+ }
+);
+
+if (__DEV__) {
+ LinkAnchor.displayName = "LinkAnchor";
}
/**
* The public API for rendering a history-aware .
*/
-function Link({ component = LinkAnchor, replace, to, ...rest }) {
- return (
-
- {context => {
- invariant(context, "You should not use outside a ");
+const Link = forwardRef(
+ (
+ { component = LinkAnchor, replace, to, innerRef, ...rest },
+ forwardedRef
+ ) => {
+ return (
+
+ {context => {
+ invariant(context, "You should not use outside a ");
- const { history } = context;
+ const { history } = context;
- const location = normalizeToLocation(
- resolveToLocation(to, context.location),
- context.location
- );
+ const location = normalizeToLocation(
+ resolveToLocation(to, context.location),
+ context.location
+ );
- const href = location ? history.createHref(location) : "";
+ const href = location ? history.createHref(location) : "";
- return React.createElement(component, {
- ...rest,
- href,
- navigate() {
- const location = resolveToLocation(to, context.location);
- const method = replace ? history.replace : history.push;
+ return React.createElement(component, {
+ ...rest,
+ ref: forwardedRef || innerRef,
+ href,
+ navigate() {
+ const location = resolveToLocation(to, context.location);
+ const method = replace ? history.replace : history.push;
- method(location);
- }
- });
- }}
-
- );
-}
+ method(location);
+ }
+ });
+ }}
+
+ );
+ }
+);
if (__DEV__) {
const toType = PropTypes.oneOfType([
@@ -82,6 +100,8 @@ if (__DEV__) {
PropTypes.shape({ current: PropTypes.any })
]);
+ Link.displayName = "Link";
+
Link.propTypes = {
innerRef: refType,
onClick: PropTypes.func,
diff --git a/packages/react-router-dom/modules/NavLink.js b/packages/react-router-dom/modules/NavLink.js
index 25e0c85caa..c1f98d1fdb 100644
--- a/packages/react-router-dom/modules/NavLink.js
+++ b/packages/react-router-dom/modules/NavLink.js
@@ -5,6 +5,12 @@ import invariant from "tiny-invariant";
import Link from "./Link";
import { resolveToLocation, normalizeToLocation } from "./utils/locationUtils";
+// React 15 compat
+let { forwardRef } = React;
+if (typeof forwardRef === "undefined") {
+ forwardRef = C => C;
+}
+
function joinClassnames(...classnames) {
return classnames.filter(i => i).join(" ");
}
@@ -12,61 +18,74 @@ function joinClassnames(...classnames) {
/**
* A wrapper that knows if it's "active" or not.
*/
-function NavLink({
- "aria-current": ariaCurrent = "page",
- activeClassName = "active",
- activeStyle,
- className: classNameProp,
- exact,
- isActive: isActiveProp,
- location: locationProp,
- strict,
- style: styleProp,
- to,
- ...rest
-}) {
- return (
-
- {context => {
- invariant(context, "You should not use outside a ");
+const NavLink = forwardRef(
+ (
+ {
+ "aria-current": ariaCurrent = "page",
+ activeClassName = "active",
+ activeStyle,
+ className: classNameProp,
+ exact,
+ isActive: isActiveProp,
+ location: locationProp,
+ strict,
+ style: styleProp,
+ to,
+ innerRef,
+ ...rest
+ },
+ forwardedRef
+ ) => {
+ return (
+
+ {context => {
+ invariant(context, "You should not use outside a ");
- const currentLocation = locationProp || context.location;
- const toLocation = normalizeToLocation(
- resolveToLocation(to, currentLocation),
- currentLocation
- );
- const { pathname: path } = toLocation;
- // Regex taken from: https://github.com/pillarjs/path-to-regexp/blob/master/index.js#L202
- const escapedPath =
- path && path.replace(/([.+*?=^!:${}()[\]|/\\])/g, "\\$1");
+ const currentLocation = locationProp || context.location;
+ const toLocation = normalizeToLocation(
+ resolveToLocation(to, currentLocation),
+ currentLocation
+ );
+ const { pathname: path } = toLocation;
+ // Regex taken from: https://github.com/pillarjs/path-to-regexp/blob/master/index.js#L202
+ const escapedPath =
+ path && path.replace(/([.+*?=^!:${}()[\]|/\\])/g, "\\$1");
- const match = escapedPath
- ? matchPath(currentLocation.pathname, { path: escapedPath, exact, strict })
- : null;
- const isActive = !!(isActiveProp
- ? isActiveProp(match, currentLocation)
- : match);
+ const match = escapedPath
+ ? matchPath(currentLocation.pathname, {
+ path: escapedPath,
+ exact,
+ strict
+ })
+ : null;
+ const isActive = !!(isActiveProp
+ ? isActiveProp(match, currentLocation)
+ : match);
- const className = isActive
- ? joinClassnames(classNameProp, activeClassName)
- : classNameProp;
- const style = isActive ? { ...styleProp, ...activeStyle } : styleProp;
+ const className = isActive
+ ? joinClassnames(classNameProp, activeClassName)
+ : classNameProp;
+ const style = isActive ? { ...styleProp, ...activeStyle } : styleProp;
- return (
-
- );
- }}
-
- );
-}
+ return (
+
+ );
+ }}
+
+ );
+ }
+);
if (__DEV__) {
+ NavLink.displayName = "NavLink";
+
const ariaCurrentType = PropTypes.oneOf([
"page",
"step",
diff --git a/packages/react-router-dom/modules/__tests__/Link-test.js b/packages/react-router-dom/modules/__tests__/Link-test.js
index c2d71bb81c..df67b6ec8f 100644
--- a/packages/react-router-dom/modules/__tests__/Link-test.js
+++ b/packages/react-router-dom/modules/__tests__/Link-test.js
@@ -107,7 +107,26 @@ describe("A ", () => {
});
});
- it("exposes its ref via an innerRef callbar prop", () => {
+ it("forwards a ref", () => {
+ let refNode;
+ function refCallback(n) {
+ refNode = n;
+ }
+
+ renderStrict(
+
+
+ link
+
+ ,
+ node
+ );
+
+ expect(refNode).not.toBe(undefined);
+ expect(refNode.tagName).toEqual("A");
+ });
+
+ it("exposes its ref via an innerRef callback prop", () => {
let refNode;
function refCallback(n) {
refNode = n;
@@ -126,6 +145,31 @@ describe("A ", () => {
expect(refNode.tagName).toEqual("A");
});
+ it("prefers forwardRef over innerRef", () => {
+ let refNode;
+ function refCallback(n) {
+ refNode = n;
+ }
+
+ renderStrict(
+
+ {
+ throw new Error("wrong ref, champ");
+ }}
+ >
+ link
+
+ ,
+ node
+ );
+
+ expect(refNode).not.toBe(undefined);
+ expect(refNode.tagName).toEqual("A");
+ });
+
it("uses a custom component prop", () => {
let linkProps;
function MyComponent(p) {
diff --git a/packages/react-router-dom/modules/__tests__/NavLink-test.js b/packages/react-router-dom/modules/__tests__/NavLink-test.js
index f06987857c..2e5b1971c1 100644
--- a/packages/react-router-dom/modules/__tests__/NavLink-test.js
+++ b/packages/react-router-dom/modules/__tests__/NavLink-test.js
@@ -11,6 +11,25 @@ describe("A ", () => {
ReactDOM.unmountComponentAtNode(node);
});
+ it("forwards a ref", () => {
+ let refNode;
+ function refCallback(n) {
+ refNode = n;
+ }
+
+ renderStrict(
+
+
+ link
+
+ ,
+ node
+ );
+
+ expect(refNode).not.toBe(undefined);
+ expect(refNode.tagName).toEqual("A");
+ });
+
describe("when active", () => {
it("applies its default activeClassName", () => {
renderStrict(
@@ -490,7 +509,11 @@ describe("A ", () => {
it("overrides the current location for isActive", () => {
renderStrict(
- location.pathname === '/pasta'} location={{ pathname: "/pasta" }}>
+ location.pathname === "/pasta"}
+ location={{ pathname: "/pasta" }}
+ >
Pasta!
,