Skip to content

Commit

Permalink
[ISSUE #6968] fix grpc acl bug (#6969)
Browse files Browse the repository at this point in the history
* feat(acl): fix acl bug

Signed-off-by: lyx <1419360299@qq.com>

# Conflicts:
#	proxy/src/main/java/org/apache/rocketmq/proxy/grpc/GrpcServerBuilder.java

* add access test for two client

Signed-off-by: lyx <1419360299@qq.com>

* use specific acl config

Signed-off-by: lyx <1419360299@qq.com>

* Recovering unchange file

Signed-off-by: lyx <1419360299@qq.com>

* let test pass

Signed-off-by: lyx <1419360299@qq.com>

---------

Signed-off-by: lyx <1419360299@qq.com>
  • Loading branch information
lyx2000 authored Jul 15, 2023
1 parent 70a66ed commit 5914ff8
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ public static PlainAccessResource parse(GeneratedMessageV3 messageV3, Authentica
if (!request.hasGroup()) {
throw new AclException("Consumer heartbeat doesn't have group");
} else {
accessResource.addResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(request.getGroup(), Permission.SUB);
}
}
} else if (SendMessageRequest.getDescriptor().getFullName().equals(rpcFullName)) {
Expand All @@ -240,15 +240,15 @@ public static PlainAccessResource parse(GeneratedMessageV3 messageV3, Authentica
accessResource.addResourceAndPerm(topic, Permission.PUB);
} else if (ReceiveMessageRequest.getDescriptor().getFullName().equals(rpcFullName)) {
ReceiveMessageRequest request = (ReceiveMessageRequest) messageV3;
accessResource.addResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addResourceAndPerm(request.getMessageQueue().getTopic(), Permission.SUB);
} else if (AckMessageRequest.getDescriptor().getFullName().equals(rpcFullName)) {
AckMessageRequest request = (AckMessageRequest) messageV3;
accessResource.addResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addResourceAndPerm(request.getTopic(), Permission.SUB);
} else if (ForwardMessageToDeadLetterQueueRequest.getDescriptor().getFullName().equals(rpcFullName)) {
ForwardMessageToDeadLetterQueueRequest request = (ForwardMessageToDeadLetterQueueRequest) messageV3;
accessResource.addResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addResourceAndPerm(request.getTopic(), Permission.SUB);
} else if (EndTransactionRequest.getDescriptor().getFullName().equals(rpcFullName)) {
EndTransactionRequest request = (EndTransactionRequest) messageV3;
Expand All @@ -264,7 +264,7 @@ public static PlainAccessResource parse(GeneratedMessageV3 messageV3, Authentica
}
if (command.getSettings().hasSubscription()) {
Subscription subscription = command.getSettings().getSubscription();
accessResource.addResourceAndPerm(subscription.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(subscription.getGroup(), Permission.SUB);
for (SubscriptionEntry entry : subscription.getSubscriptionsList()) {
accessResource.addResourceAndPerm(entry.getTopic(), Permission.SUB);
}
Expand All @@ -275,17 +275,17 @@ public static PlainAccessResource parse(GeneratedMessageV3 messageV3, Authentica
}
} else if (NotifyClientTerminationRequest.getDescriptor().getFullName().equals(rpcFullName)) {
NotifyClientTerminationRequest request = (NotifyClientTerminationRequest) messageV3;
accessResource.addResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(request.getGroup(), Permission.SUB);
} else if (QueryRouteRequest.getDescriptor().getFullName().equals(rpcFullName)) {
QueryRouteRequest request = (QueryRouteRequest) messageV3;
accessResource.addResourceAndPerm(request.getTopic(), Permission.ANY);
} else if (QueryAssignmentRequest.getDescriptor().getFullName().equals(rpcFullName)) {
QueryAssignmentRequest request = (QueryAssignmentRequest) messageV3;
accessResource.addResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addResourceAndPerm(request.getTopic(), Permission.SUB);
} else if (ChangeInvisibleDurationRequest.getDescriptor().getFullName().equals(rpcFullName)) {
ChangeInvisibleDurationRequest request = (ChangeInvisibleDurationRequest) messageV3;
accessResource.addResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addResourceAndPerm(request.getTopic(), Permission.SUB);
}
} catch (Throwable t) {
Expand All @@ -299,6 +299,11 @@ private void addResourceAndPerm(Resource resource, byte permission) {
addResourceAndPerm(resourceName, permission);
}

private void addGroupResourceAndPerm(Resource resource, byte permission) {
String resourceName = NamespaceUtil.wrapNamespace(resource.getResourceNamespace(), resource.getName());
addResourceAndPerm(getRetryTopic(resourceName), permission);
}

public static PlainAccessResource build(PlainAccessConfig plainAccessConfig, RemoteAddressStrategy remoteAddressStrategy) {
PlainAccessResource plainAccessResource = new PlainAccessResource();
plainAccessResource.setAccessKey(plainAccessConfig.getAccessKey());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.rocketmq.acl;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.rocketmq.acl.common.AclClientRPCHook;
import org.apache.rocketmq.acl.common.AclException;
import org.apache.rocketmq.acl.common.SessionCredentials;
import org.apache.rocketmq.acl.plain.AclTestHelper;
import org.apache.rocketmq.acl.plain.PlainAccessResource;
import org.apache.rocketmq.acl.plain.PlainAccessValidator;
import org.apache.rocketmq.remoting.exception.RemotingCommandException;
import org.apache.rocketmq.remoting.protocol.RemotingCommand;
import org.apache.rocketmq.remoting.protocol.RequestCode;
import org.apache.rocketmq.remoting.protocol.header.PullMessageRequestHeader;
import org.apache.rocketmq.remoting.protocol.header.SendMessageRequestHeader;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

public class RemotingClientAccessTest {

private PlainAccessValidator plainAccessValidator;
private AclClientRPCHook aclClient;
private SessionCredentials sessionCredentials;

private File confHome;

private String clientAddress = "10.7.1.3";

@Before
public void init() throws IOException {
String folder = "access_acl_conf";
confHome = AclTestHelper.copyResources(folder, true);
System.setProperty("rocketmq.home.dir", confHome.getAbsolutePath());
System.setProperty("rocketmq.acl.plain.file", "/access_acl_conf/acl/plain_acl.yml".replace("/", File.separator));

plainAccessValidator = new PlainAccessValidator();
sessionCredentials = new SessionCredentials();
sessionCredentials.setAccessKey("rocketmq3");
sessionCredentials.setSecretKey("12345678");
aclClient = new AclClientRPCHook(sessionCredentials);
}

@After
public void cleanUp() {
AclTestHelper.recursiveDelete(confHome);
}

@Test(expected = AclException.class)
public void testProduceDenyTopic() {
SendMessageRequestHeader messageRequestHeader = new SendMessageRequestHeader();
messageRequestHeader.setTopic("topicD");
RemotingCommand remotingCommand = RemotingCommand.createRequestCommand(RequestCode.SEND_MESSAGE, messageRequestHeader);
aclClient.doBeforeRequest(clientAddress, remotingCommand);

ByteBuffer buf = remotingCommand.encodeHeader();
buf.getInt();
buf = ByteBuffer.allocate(buf.limit() - buf.position()).put(buf);
buf.position(0);
try {
PlainAccessResource accessResource = (PlainAccessResource) plainAccessValidator.parse(RemotingCommand.decode(buf), clientAddress);
plainAccessValidator.validate(accessResource);
} catch (RemotingCommandException e) {
e.printStackTrace();
Assert.fail("Should not throw IOException");
}
}

@Test
public void testProduceAuthorizedTopic() {
SendMessageRequestHeader messageRequestHeader = new SendMessageRequestHeader();
messageRequestHeader.setTopic("topicA");
RemotingCommand remotingCommand = RemotingCommand.createRequestCommand(RequestCode.SEND_MESSAGE, messageRequestHeader);
aclClient.doBeforeRequest(clientAddress, remotingCommand);

ByteBuffer buf = remotingCommand.encodeHeader();
buf.getInt();
buf = ByteBuffer.allocate(buf.limit() - buf.position()).put(buf);
buf.position(0);
try {
PlainAccessResource accessResource = (PlainAccessResource) plainAccessValidator.parse(RemotingCommand.decode(buf), clientAddress);
plainAccessValidator.validate(accessResource);
} catch (RemotingCommandException e) {
e.printStackTrace();
Assert.fail("Should not throw IOException");
}
}


@Test(expected = AclException.class)
public void testConsumeDenyTopic() {
PullMessageRequestHeader pullMessageRequestHeader = new PullMessageRequestHeader();
pullMessageRequestHeader.setTopic("topicD");
pullMessageRequestHeader.setConsumerGroup("groupB");
RemotingCommand remotingCommand = RemotingCommand.createRequestCommand(RequestCode.PULL_MESSAGE, pullMessageRequestHeader);
aclClient.doBeforeRequest("", remotingCommand);
ByteBuffer buf = remotingCommand.encodeHeader();
buf.getInt();
buf = ByteBuffer.allocate(buf.limit() - buf.position()).put(buf);
buf.position(0);
try {
PlainAccessResource accessResource = (PlainAccessResource) plainAccessValidator.parse(RemotingCommand.decode(buf), "123.4.5.6");
plainAccessValidator.validate(accessResource);
} catch (RemotingCommandException e) {
e.printStackTrace();
Assert.fail("Should not throw IOException");
}

}

@Test
public void testConsumeAuthorizedTopic() {
PullMessageRequestHeader pullMessageRequestHeader = new PullMessageRequestHeader();
pullMessageRequestHeader.setTopic("topicB");
pullMessageRequestHeader.setConsumerGroup("groupB");
RemotingCommand remotingCommand = RemotingCommand.createRequestCommand(RequestCode.PULL_MESSAGE, pullMessageRequestHeader);
aclClient.doBeforeRequest("", remotingCommand);
ByteBuffer buf = remotingCommand.encodeHeader();
buf.getInt();
buf = ByteBuffer.allocate(buf.limit() - buf.position()).put(buf);
buf.position(0);
try {
PlainAccessResource accessResource = (PlainAccessResource) plainAccessValidator.parse(RemotingCommand.decode(buf), "123.4.5.6");
plainAccessValidator.validate(accessResource);
} catch (RemotingCommandException e) {
e.printStackTrace();
Assert.fail("Should not throw IOException");
}
}

@Test(expected = AclException.class)
public void testConsumeInDeniedGroup() {
PullMessageRequestHeader pullMessageRequestHeader = new PullMessageRequestHeader();
pullMessageRequestHeader.setTopic("topicB");
pullMessageRequestHeader.setConsumerGroup("groupD");
RemotingCommand remotingCommand = RemotingCommand.createRequestCommand(RequestCode.PULL_MESSAGE, pullMessageRequestHeader);
aclClient.doBeforeRequest("", remotingCommand);
ByteBuffer buf = remotingCommand.encodeHeader();
buf.getInt();
buf = ByteBuffer.allocate(buf.limit() - buf.position()).put(buf);
buf.position(0);
try {
PlainAccessResource accessResource = (PlainAccessResource) plainAccessValidator.parse(RemotingCommand.decode(buf), "123.4.5.6");
plainAccessValidator.validate(accessResource);
} catch (RemotingCommandException e) {
e.printStackTrace();
Assert.fail("Should not throw IOException");
}
}

@Test
public void testConsumeInAuthorizedGroup() {
PullMessageRequestHeader pullMessageRequestHeader = new PullMessageRequestHeader();
pullMessageRequestHeader.setTopic("topicB");
pullMessageRequestHeader.setConsumerGroup("groupB");
RemotingCommand remotingCommand = RemotingCommand.createRequestCommand(RequestCode.PULL_MESSAGE, pullMessageRequestHeader);
aclClient.doBeforeRequest("", remotingCommand);
ByteBuffer buf = remotingCommand.encodeHeader();
buf.getInt();
buf = ByteBuffer.allocate(buf.limit() - buf.position()).put(buf);
buf.position(0);
try {
PlainAccessResource accessResource = (PlainAccessResource) plainAccessValidator.parse(RemotingCommand.decode(buf), "123.4.5.6");
plainAccessValidator.validate(accessResource);
} catch (RemotingCommandException e) {
e.printStackTrace();
Assert.fail("Should not throw IOException");
}
}

}
31 changes: 31 additions & 0 deletions acl/src/test/resources/access_acl_conf/acl/plain_acl.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

accounts:
- accessKey: rocketmq3
secretKey: 12345678
admin: false
defaultTopicPerm: DENY
defaultGroupPerm: DENY
topicPerms:
- topicA=PUB
- topicB=SUB
- topicC=PUB|SUB
- topicD=DENY
groupPerms:
- groupB=SUB
- groupC=PUB|SUB
- groupD=DENY

1 change: 0 additions & 1 deletion acl/src/test/resources/conf/acl/plain_acl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,3 @@ accounts:
whiteRemoteAddress: 192.168.1.*
# if it is admin, it could access all resources
admin: true

1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@
<awaitility.version>4.1.0</awaitility.version>
<truth.version>0.30</truth.version>
<s3mock-junit4.version>2.11.0</s3mock-junit4.version>
<rocketmq-client-java.version>5.0.5</rocketmq-client-java.version>

<!-- Build plugin dependencies -->
<versions-maven-plugin.version>2.2</versions-maven-plugin.version>
Expand Down
17 changes: 15 additions & 2 deletions proxy/src/main/java/org/apache/rocketmq/proxy/ProxyStartup.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.lang3.StringUtils;
import org.apache.rocketmq.acl.AccessValidator;
import org.apache.rocketmq.acl.plain.PlainAccessValidator;
import org.apache.rocketmq.broker.BrokerController;
import org.apache.rocketmq.broker.BrokerStartup;
import org.apache.rocketmq.common.MixAll;
import org.apache.rocketmq.common.constant.LoggerName;
import org.apache.rocketmq.common.thread.ThreadPoolMonitor;
import org.apache.rocketmq.common.utils.ServiceProvider;
import org.apache.rocketmq.logging.org.slf4j.Logger;
import org.apache.rocketmq.logging.org.slf4j.LoggerFactory;
import org.apache.rocketmq.common.utils.AbstractStartAndShutdown;
Expand Down Expand Up @@ -75,16 +78,17 @@ public static void main(String[] args) {

MessagingProcessor messagingProcessor = createMessagingProcessor();

List<AccessValidator> accessValidators = loadAccessValidators();
// create grpcServer
GrpcServer grpcServer = GrpcServerBuilder.newBuilder(executor, ConfigurationManager.getProxyConfig().getGrpcServerPort())
.addService(createServiceProcessor(messagingProcessor))
.addService(ChannelzService.newInstance(100))
.addService(ProtoReflectionService.newInstance())
.configInterceptor()
.configInterceptor(accessValidators)
.build();
PROXY_START_AND_SHUTDOWN.appendStartAndShutdown(grpcServer);

RemotingProtocolServer remotingServer = new RemotingProtocolServer(messagingProcessor);
RemotingProtocolServer remotingServer = new RemotingProtocolServer(messagingProcessor, accessValidators);
PROXY_START_AND_SHUTDOWN.appendStartAndShutdown(remotingServer);

// start servers one by one.
Expand All @@ -109,6 +113,15 @@ public static void main(String[] args) {
log.info(new Date() + " rocketmq-proxy startup successfully");
}

protected static List<AccessValidator> loadAccessValidators() {
List<AccessValidator> accessValidators = ServiceProvider.load(AccessValidator.class);
if (accessValidators.isEmpty()) {
log.info("ServiceProvider loaded no AccessValidator, using default org.apache.rocketmq.acl.plain.PlainAccessValidator");
accessValidators.add(new PlainAccessValidator());
}
return accessValidators;
}

protected static void initConfiguration(CommandLineArgument commandLineArgument) throws Exception {
if (StringUtils.isNotBlank(commandLineArgument.getProxyConfigPath())) {
System.setProperty(Configuration.CONFIG_PATH_PROPERTY, commandLineArgument.getProxyConfigPath());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.rocketmq.acl.AccessValidator;
import org.apache.rocketmq.acl.plain.PlainAccessValidator;
import org.apache.rocketmq.common.constant.LoggerName;
import org.apache.rocketmq.common.utils.ServiceProvider;
import org.apache.rocketmq.logging.org.slf4j.Logger;
import org.apache.rocketmq.logging.org.slf4j.LoggerFactory;
import org.apache.rocketmq.proxy.config.ConfigurationManager;
Expand Down Expand Up @@ -98,14 +96,8 @@ public GrpcServer build() {
return new GrpcServer(this.serverBuilder.build());
}

public GrpcServerBuilder configInterceptor() {
public GrpcServerBuilder configInterceptor(List<AccessValidator> accessValidators) {
// grpc interceptors, including acl, logging etc.
List<AccessValidator> accessValidators = ServiceProvider.load(AccessValidator.class);
if (accessValidators.isEmpty()) {
log.info("ServiceProvider loaded no AccessValidator, using default org.apache.rocketmq.acl.plain.PlainAccessValidator");
accessValidators.add(new PlainAccessValidator());
}

this.serverBuilder.intercept(new AuthenticationInterceptor(accessValidators));

this.serverBuilder
Expand Down
Loading

0 comments on commit 5914ff8

Please sign in to comment.