Skip to content

Commit

Permalink
SmbClient cleanup when connection is getting closed serverList not cl…
Browse files Browse the repository at this point in the history
…eared (#719)

* //Azure SMB drive is locally redundant which means multiple instances exist and next Connection request it might be different one

* Add test to prove ServerList is cleared on last connection close()

- Introduced testing classes using JUnit5

Signed-off-by: Jeroen van Erp <jeroen@hierynomus.com>

* Codacy styling fixes

Signed-off-by: Jeroen van Erp <jeroen@hierynomus.com>

---------

Signed-off-by: Jeroen van Erp <jeroen@hierynomus.com>
Co-authored-by: Jozef Dropco <Jozef_Dropco@swissre.com>
Co-authored-by: Jeroen van Erp <jeroen@hierynomus.com>
  • Loading branch information
3 people committed Aug 16, 2023
1 parent 44f7401 commit 8b921f0
Show file tree
Hide file tree
Showing 20 changed files with 526 additions and 244 deletions.
13 changes: 1 addition & 12 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ testing {
configureEach {
useJUnitJupiter()
dependencies {
implementation project()
implementation SLF4J_API
// implementation GROOVY_ALL
runtimeOnly CGLIB_NODEP
Expand Down Expand Up @@ -119,19 +120,7 @@ testing {
}
}

test {
sources {
groovy {
srcDirs = ['src/test/groovy']
}
}
}

integrationTest(JvmTestSuite) {
dependencies {
implementation project()
}

sources {
java {
srcDirs = ['src/it/java']
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/com/hierynomus/mssmb2/SMB2Error.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public class SMB2Error {

private List<SMB2ErrorData> errorData = new ArrayList<>();

SMB2Error() {
public SMB2Error() {
}

SMB2Error read(SMB2PacketHeader header, SMBBuffer buffer) throws Buffer.BufferException {
Expand Down
6 changes: 5 additions & 1 deletion src/main/java/com/hierynomus/mssmb2/SMB2Packet.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class SMB2Packet extends SMBPacket<SMB2PacketData, SMB2PacketHeader> {
protected int structureSize;
private SMB2Error error;

protected SMB2Packet() {
public SMB2Packet() {
super(new SMB2PacketHeader());
}

Expand Down Expand Up @@ -144,6 +144,10 @@ public SMB2Error getError() {
return error;
}

public void setError(SMB2Error error) {
this.error = error;
}

/**
* Method that can be overridden by Packet Wrappers to ensure that the original (typed) packet is obtainable.
*
Expand Down
10 changes: 9 additions & 1 deletion src/main/java/com/hierynomus/mssmb2/SMB2PacketHeader.java
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public void setMessageId(long messageId) {
this.messageId = messageId;
}

void setMessageType(SMB2MessageCommandCode messageType) {
public void setMessageType(SMB2MessageCommandCode messageType) {
this.message = messageType;
}

Expand Down Expand Up @@ -151,10 +151,18 @@ public void setCreditRequest(int creditRequest) {
this.creditRequest = creditRequest;
}

public int getCreditRequest() {
return creditRequest;
}

public int getCreditResponse() {
return creditResponse;
}

public void setCreditResponse(int creditResponse) {
this.creditResponse = creditResponse;
}

public void setAsyncId(long asyncId) {
this.asyncId = asyncId;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,16 @@ public FileTime getServerStartTime() {
public List<SMB2NegotiateContext> getNegotiateContextList() {
return negotiateContextList;
}

public void setDialect(SMB2Dialect dialect) {
this.dialect = dialect;
}

public void setSystemTime(FileTime systemTime) {
this.systemTime = systemTime;
}

public void setServerGuid(UUID serverGuid) {
this.serverGuid = serverGuid;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ public Set<SMB2SessionFlags> getSessionFlags() {
return sessionFlags;
}

public void setSessionFlags(Set<SMB2SessionFlags> sessionFlags) {
this.sessionFlags = sessionFlags;
}

public void setPreviousSessionId(long previousSessionId) {
this.previousSessionId = previousSessionId;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ protected void readMessage(SMBBuffer buffer) throws Buffer.BufferException {
maximalAccess = toEnumSet(buffer.readUInt32(), AccessMask.class); // MaximalAccess (4 bytes)
}

public void setShareType(byte shareType) {
this.shareType = shareType;
}

/**
* Whether the ShareType returned is SMB2_SHARE_TYPE_DISK (0x01)
*
Expand Down Expand Up @@ -79,10 +83,18 @@ public Set<SMB2ShareFlags> getShareFlags() {
return shareFlags;
}

public void setShareFlags(Set<SMB2ShareFlags> shareFlags) {
this.shareFlags = shareFlags;
}

public Set<SMB2ShareCapabilities> getCapabilities() {
return capabilities;
}

public void setCapabilities(Set<SMB2ShareCapabilities> capabilities) {
this.capabilities = capabilities;
}

public Set<AccessMask> getMaximalAccess() {
return maximalAccess;
}
Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/hierynomus/smbj/SMBClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ private void connectionClosed(ConnectionClosed event) {
synchronized (this) {
String hostPort = event.getHostname() + ":" + event.getPort();
connectionTable.remove(hostPort);
serverList.unregister(event.getHostname());
logger.debug("Connection to << {} >> closed", hostPort);
}
}
Expand Down
8 changes: 6 additions & 2 deletions src/main/java/com/hierynomus/smbj/connection/Connection.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public class Connection extends Pooled<Connection> implements Closeable, PacketR
private SessionTable preauthSessionTable = new SessionTable();
OutstandingRequests outstandingRequests = new OutstandingRequests();
SequenceWindow sequenceWindow;
private SMB2MessageConverter smb2Converter = new SMB2MessageConverter();
private SMB2MessageConverter messageConverter = new SMB2MessageConverter();
private PathResolver pathResolver;

private final SMBClient client;
Expand Down Expand Up @@ -119,7 +119,7 @@ private void init() {
new SMB2SignatureVerificationPacketHandler(sessionTable, signatory).setNext(
new SMB2CreditGrantingPacketHandler(sequenceWindow).setNext(
new SMB2AsyncResponsePacketHandler(outstandingRequests).setNext(
new SMB2ProcessResponsePacketHandler(smb2Converter, outstandingRequests).setNext(
new SMB2ProcessResponsePacketHandler(messageConverter, outstandingRequests).setNext(
new SMB1PacketHandler().setNext(new DeadLetterPacketHandler()))))))));
}

Expand Down Expand Up @@ -382,4 +382,8 @@ SessionTable getSessionTable() {
SessionTable getPreauthSessionTable() {
return preauthSessionTable;
}

public void setMessageConverter(SMB2MessageConverter smb2Converter) {
this.messageConverter = smb2Converter;
}
}
8 changes: 4 additions & 4 deletions src/test/groovy/com/hierynomus/smbj/SMBClientSpec.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
*/
package com.hierynomus.smbj

import com.hierynomus.smbj.connection.BasicPacketProcessor
import com.hierynomus.smbj.connection.StubTransportLayerFactory
import com.hierynomus.smbj.testing.PacketProcessor.DefaultPacketProcessor
import com.hierynomus.smbj.testing.StubTransportLayerFactory
import spock.lang.Specification

class SMBClientSpec extends Specification {

def processor = new BasicPacketProcessor({ req -> null })
def config = SmbConfig.builder().withTransportLayerFactory(new StubTransportLayerFactory(processor.&processPacket)).build()
def processor = new DefaultPacketProcessor()
def config = SmbConfig.builder().withTransportLayerFactory(new StubTransportLayerFactory(processor)).build()

def "should return same connection for same host/port combo"() {
given:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,22 @@ import com.hierynomus.smbj.event.SMBEvent
import com.hierynomus.smbj.event.SMBEventBus
import com.hierynomus.smbj.event.SessionLoggedOff
import com.hierynomus.protocol.transport.TransportException
import com.hierynomus.smbj.testing.PacketProcessor.NoOpPacketProcessor
import com.hierynomus.smbj.testing.PacketProcessor.DefaultPacketProcessor
import com.hierynomus.smbj.testing.StubAuthenticator
import com.hierynomus.smbj.testing.StubTransportLayerFactory
import net.engio.mbassy.listener.Handler
import spock.lang.Specification

class ConnectionSpec extends Specification {

def bus = new SMBEventBus()
def packetProcessor = { req -> null }
def packetProcessor = new NoOpPacketProcessor()
def config = smbConfig(packetProcessor)

private SmbConfig smbConfig(packetProcessor) {
SmbConfig.builder()
.withTransportLayerFactory(new StubTransportLayerFactory(new BasicPacketProcessor(packetProcessor).&processPacket))
.withTransportLayerFactory(new StubTransportLayerFactory(new DefaultPacketProcessor().wrap(packetProcessor)))
.withAuthenticators(new StubAuthenticator.Factory())
.build()
}
Expand Down Expand Up @@ -152,41 +156,6 @@ class ConnectionSpec extends Specification {
!(conn.pathResolver instanceof DFSPathResolver)
}

def "should remove server from serverlist if identification changed"() {
given:
def sent = false
config = smbConfig({ req ->
req = req.packet
if (!sent && req instanceof SMB2NegotiateRequest) {
sent = true
def response = new SMB2NegotiateResponse()
response.header.message = SMB2MessageCommandCode.SMB2_NEGOTIATE
response.header.statusCode = NtStatus.STATUS_SUCCESS.value
response.dialect = SMB2Dialect.SMB_2_1
response.systemTime = FileTime.now();
response.serverGuid = UUID.fromString("ffeeddcc-bbaa-9988-7766-554433221100")
return response
}
})
client = new SMBClient(config)

when:
def conn = client.connect("foo")
conn.close()

conn = client.connect("foo")

then:
thrown(TransportException)

when:
client.getServerList().unregister("foo")
conn = client.connect("foo")

then:
noExceptionThrown()
}

def "should add DFS path resolver if server supports DFS"() {
given:
config = smbConfig({ req ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@ import com.hierynomus.smbj.SMBClient
import com.hierynomus.smbj.SmbConfig
import com.hierynomus.smbj.event.ConnectionClosed
import com.hierynomus.smbj.event.SMBEventBus
import com.hierynomus.smbj.testing.PacketProcessor.DefaultPacketProcessor
import com.hierynomus.smbj.testing.StubAuthenticator
import com.hierynomus.smbj.testing.StubTransportLayerFactory
import spock.lang.Specification

class ProtocolNegotiatorSpec extends Specification {
def bus = new SMBEventBus()

private SmbConfig buildConfig(SmbConfig.Builder builder, packetProcessor) {
builder
.withTransportLayerFactory(new StubTransportLayerFactory(new BasicPacketProcessor(packetProcessor).&processPacket))
.withTransportLayerFactory(new StubTransportLayerFactory(new DefaultPacketProcessor().wrap(packetProcessor)))
.withAuthenticators(new StubAuthenticator.Factory())
.build()
}
Expand Down

This file was deleted.

Loading

0 comments on commit 8b921f0

Please sign in to comment.