diff --git a/build.gradle b/build.gradle index 29087fc3..26f74771 100644 --- a/build.gradle +++ b/build.gradle @@ -90,6 +90,7 @@ testing { configureEach { useJUnitJupiter() dependencies { + implementation project() implementation SLF4J_API // implementation GROOVY_ALL runtimeOnly CGLIB_NODEP @@ -119,19 +120,7 @@ testing { } } - test { - sources { - groovy { - srcDirs = ['src/test/groovy'] - } - } - } - integrationTest(JvmTestSuite) { - dependencies { - implementation project() - } - sources { java { srcDirs = ['src/it/java'] diff --git a/src/main/java/com/hierynomus/mssmb2/SMB2Error.java b/src/main/java/com/hierynomus/mssmb2/SMB2Error.java index 8b877bd1..26d58916 100644 --- a/src/main/java/com/hierynomus/mssmb2/SMB2Error.java +++ b/src/main/java/com/hierynomus/mssmb2/SMB2Error.java @@ -30,7 +30,7 @@ public class SMB2Error { private List errorData = new ArrayList<>(); - SMB2Error() { + public SMB2Error() { } SMB2Error read(SMB2PacketHeader header, SMBBuffer buffer) throws Buffer.BufferException { diff --git a/src/main/java/com/hierynomus/mssmb2/SMB2Packet.java b/src/main/java/com/hierynomus/mssmb2/SMB2Packet.java index 4d4a6777..2e1b46ab 100644 --- a/src/main/java/com/hierynomus/mssmb2/SMB2Packet.java +++ b/src/main/java/com/hierynomus/mssmb2/SMB2Packet.java @@ -28,7 +28,7 @@ public class SMB2Packet extends SMBPacket { protected int structureSize; private SMB2Error error; - protected SMB2Packet() { + public SMB2Packet() { super(new SMB2PacketHeader()); } @@ -144,6 +144,10 @@ public SMB2Error getError() { return error; } + public void setError(SMB2Error error) { + this.error = error; + } + /** * Method that can be overridden by Packet Wrappers to ensure that the original (typed) packet is obtainable. * diff --git a/src/main/java/com/hierynomus/mssmb2/SMB2PacketHeader.java b/src/main/java/com/hierynomus/mssmb2/SMB2PacketHeader.java index a0c0ac13..6ec1e9ff 100644 --- a/src/main/java/com/hierynomus/mssmb2/SMB2PacketHeader.java +++ b/src/main/java/com/hierynomus/mssmb2/SMB2PacketHeader.java @@ -107,7 +107,7 @@ public void setMessageId(long messageId) { this.messageId = messageId; } - void setMessageType(SMB2MessageCommandCode messageType) { + public void setMessageType(SMB2MessageCommandCode messageType) { this.message = messageType; } @@ -151,10 +151,18 @@ public void setCreditRequest(int creditRequest) { this.creditRequest = creditRequest; } + public int getCreditRequest() { + return creditRequest; + } + public int getCreditResponse() { return creditResponse; } + public void setCreditResponse(int creditResponse) { + this.creditResponse = creditResponse; + } + public void setAsyncId(long asyncId) { this.asyncId = asyncId; } diff --git a/src/main/java/com/hierynomus/mssmb2/messages/SMB2NegotiateResponse.java b/src/main/java/com/hierynomus/mssmb2/messages/SMB2NegotiateResponse.java index e8dce790..0025f49f 100644 --- a/src/main/java/com/hierynomus/mssmb2/messages/SMB2NegotiateResponse.java +++ b/src/main/java/com/hierynomus/mssmb2/messages/SMB2NegotiateResponse.java @@ -156,4 +156,16 @@ public FileTime getServerStartTime() { public List getNegotiateContextList() { return negotiateContextList; } + + public void setDialect(SMB2Dialect dialect) { + this.dialect = dialect; + } + + public void setSystemTime(FileTime systemTime) { + this.systemTime = systemTime; + } + + public void setServerGuid(UUID serverGuid) { + this.serverGuid = serverGuid; + } } diff --git a/src/main/java/com/hierynomus/mssmb2/messages/SMB2SessionSetup.java b/src/main/java/com/hierynomus/mssmb2/messages/SMB2SessionSetup.java index 1a02091b..16442332 100644 --- a/src/main/java/com/hierynomus/mssmb2/messages/SMB2SessionSetup.java +++ b/src/main/java/com/hierynomus/mssmb2/messages/SMB2SessionSetup.java @@ -99,6 +99,10 @@ public Set getSessionFlags() { return sessionFlags; } + public void setSessionFlags(Set sessionFlags) { + this.sessionFlags = sessionFlags; + } + public void setPreviousSessionId(long previousSessionId) { this.previousSessionId = previousSessionId; } diff --git a/src/main/java/com/hierynomus/mssmb2/messages/SMB2TreeConnectResponse.java b/src/main/java/com/hierynomus/mssmb2/messages/SMB2TreeConnectResponse.java index bddb7248..0187fa8a 100644 --- a/src/main/java/com/hierynomus/mssmb2/messages/SMB2TreeConnectResponse.java +++ b/src/main/java/com/hierynomus/mssmb2/messages/SMB2TreeConnectResponse.java @@ -48,6 +48,10 @@ protected void readMessage(SMBBuffer buffer) throws Buffer.BufferException { maximalAccess = toEnumSet(buffer.readUInt32(), AccessMask.class); // MaximalAccess (4 bytes) } + public void setShareType(byte shareType) { + this.shareType = shareType; + } + /** * Whether the ShareType returned is SMB2_SHARE_TYPE_DISK (0x01) * @@ -79,10 +83,18 @@ public Set getShareFlags() { return shareFlags; } + public void setShareFlags(Set shareFlags) { + this.shareFlags = shareFlags; + } + public Set getCapabilities() { return capabilities; } + public void setCapabilities(Set capabilities) { + this.capabilities = capabilities; + } + public Set getMaximalAccess() { return maximalAccess; } diff --git a/src/main/java/com/hierynomus/smbj/SMBClient.java b/src/main/java/com/hierynomus/smbj/SMBClient.java index 71041211..584f1bbf 100644 --- a/src/main/java/com/hierynomus/smbj/SMBClient.java +++ b/src/main/java/com/hierynomus/smbj/SMBClient.java @@ -112,6 +112,7 @@ private void connectionClosed(ConnectionClosed event) { synchronized (this) { String hostPort = event.getHostname() + ":" + event.getPort(); connectionTable.remove(hostPort); + serverList.unregister(event.getHostname()); logger.debug("Connection to << {} >> closed", hostPort); } } diff --git a/src/main/java/com/hierynomus/smbj/connection/Connection.java b/src/main/java/com/hierynomus/smbj/connection/Connection.java index 9226b9a5..60235f9a 100644 --- a/src/main/java/com/hierynomus/smbj/connection/Connection.java +++ b/src/main/java/com/hierynomus/smbj/connection/Connection.java @@ -80,7 +80,7 @@ public class Connection extends Pooled implements Closeable, PacketR private SessionTable preauthSessionTable = new SessionTable(); OutstandingRequests outstandingRequests = new OutstandingRequests(); SequenceWindow sequenceWindow; - private SMB2MessageConverter smb2Converter = new SMB2MessageConverter(); + private SMB2MessageConverter messageConverter = new SMB2MessageConverter(); private PathResolver pathResolver; private final SMBClient client; @@ -119,7 +119,7 @@ private void init() { new SMB2SignatureVerificationPacketHandler(sessionTable, signatory).setNext( new SMB2CreditGrantingPacketHandler(sequenceWindow).setNext( new SMB2AsyncResponsePacketHandler(outstandingRequests).setNext( - new SMB2ProcessResponsePacketHandler(smb2Converter, outstandingRequests).setNext( + new SMB2ProcessResponsePacketHandler(messageConverter, outstandingRequests).setNext( new SMB1PacketHandler().setNext(new DeadLetterPacketHandler())))))))); } @@ -382,4 +382,8 @@ SessionTable getSessionTable() { SessionTable getPreauthSessionTable() { return preauthSessionTable; } + + public void setMessageConverter(SMB2MessageConverter smb2Converter) { + this.messageConverter = smb2Converter; + } } diff --git a/src/test/groovy/com/hierynomus/smbj/SMBClientSpec.groovy b/src/test/groovy/com/hierynomus/smbj/SMBClientSpec.groovy index 03a41320..35a32344 100644 --- a/src/test/groovy/com/hierynomus/smbj/SMBClientSpec.groovy +++ b/src/test/groovy/com/hierynomus/smbj/SMBClientSpec.groovy @@ -15,14 +15,14 @@ */ package com.hierynomus.smbj -import com.hierynomus.smbj.connection.BasicPacketProcessor -import com.hierynomus.smbj.connection.StubTransportLayerFactory +import com.hierynomus.smbj.testing.PacketProcessor.DefaultPacketProcessor +import com.hierynomus.smbj.testing.StubTransportLayerFactory import spock.lang.Specification class SMBClientSpec extends Specification { - def processor = new BasicPacketProcessor({ req -> null }) - def config = SmbConfig.builder().withTransportLayerFactory(new StubTransportLayerFactory(processor.&processPacket)).build() + def processor = new DefaultPacketProcessor() + def config = SmbConfig.builder().withTransportLayerFactory(new StubTransportLayerFactory(processor)).build() def "should return same connection for same host/port combo"() { given: diff --git a/src/test/groovy/com/hierynomus/smbj/connection/ConnectionSpec.groovy b/src/test/groovy/com/hierynomus/smbj/connection/ConnectionSpec.groovy index 31b0b3d7..197f54ab 100644 --- a/src/test/groovy/com/hierynomus/smbj/connection/ConnectionSpec.groovy +++ b/src/test/groovy/com/hierynomus/smbj/connection/ConnectionSpec.groovy @@ -35,18 +35,22 @@ import com.hierynomus.smbj.event.SMBEvent import com.hierynomus.smbj.event.SMBEventBus import com.hierynomus.smbj.event.SessionLoggedOff import com.hierynomus.protocol.transport.TransportException +import com.hierynomus.smbj.testing.PacketProcessor.NoOpPacketProcessor +import com.hierynomus.smbj.testing.PacketProcessor.DefaultPacketProcessor +import com.hierynomus.smbj.testing.StubAuthenticator +import com.hierynomus.smbj.testing.StubTransportLayerFactory import net.engio.mbassy.listener.Handler import spock.lang.Specification class ConnectionSpec extends Specification { def bus = new SMBEventBus() - def packetProcessor = { req -> null } + def packetProcessor = new NoOpPacketProcessor() def config = smbConfig(packetProcessor) private SmbConfig smbConfig(packetProcessor) { SmbConfig.builder() - .withTransportLayerFactory(new StubTransportLayerFactory(new BasicPacketProcessor(packetProcessor).&processPacket)) + .withTransportLayerFactory(new StubTransportLayerFactory(new DefaultPacketProcessor().wrap(packetProcessor))) .withAuthenticators(new StubAuthenticator.Factory()) .build() } @@ -152,41 +156,6 @@ class ConnectionSpec extends Specification { !(conn.pathResolver instanceof DFSPathResolver) } - def "should remove server from serverlist if identification changed"() { - given: - def sent = false - config = smbConfig({ req -> - req = req.packet - if (!sent && req instanceof SMB2NegotiateRequest) { - sent = true - def response = new SMB2NegotiateResponse() - response.header.message = SMB2MessageCommandCode.SMB2_NEGOTIATE - response.header.statusCode = NtStatus.STATUS_SUCCESS.value - response.dialect = SMB2Dialect.SMB_2_1 - response.systemTime = FileTime.now(); - response.serverGuid = UUID.fromString("ffeeddcc-bbaa-9988-7766-554433221100") - return response - } - }) - client = new SMBClient(config) - - when: - def conn = client.connect("foo") - conn.close() - - conn = client.connect("foo") - - then: - thrown(TransportException) - - when: - client.getServerList().unregister("foo") - conn = client.connect("foo") - - then: - noExceptionThrown() - } - def "should add DFS path resolver if server supports DFS"() { given: config = smbConfig({ req -> diff --git a/src/test/groovy/com/hierynomus/smbj/connection/ProtocolNegotiatorSpec.groovy b/src/test/groovy/com/hierynomus/smbj/connection/ProtocolNegotiatorSpec.groovy index efb525ff..d0d3a419 100644 --- a/src/test/groovy/com/hierynomus/smbj/connection/ProtocolNegotiatorSpec.groovy +++ b/src/test/groovy/com/hierynomus/smbj/connection/ProtocolNegotiatorSpec.groovy @@ -25,6 +25,9 @@ import com.hierynomus.smbj.SMBClient import com.hierynomus.smbj.SmbConfig import com.hierynomus.smbj.event.ConnectionClosed import com.hierynomus.smbj.event.SMBEventBus +import com.hierynomus.smbj.testing.PacketProcessor.DefaultPacketProcessor +import com.hierynomus.smbj.testing.StubAuthenticator +import com.hierynomus.smbj.testing.StubTransportLayerFactory import spock.lang.Specification class ProtocolNegotiatorSpec extends Specification { @@ -32,7 +35,7 @@ class ProtocolNegotiatorSpec extends Specification { private SmbConfig buildConfig(SmbConfig.Builder builder, packetProcessor) { builder - .withTransportLayerFactory(new StubTransportLayerFactory(new BasicPacketProcessor(packetProcessor).&processPacket)) + .withTransportLayerFactory(new StubTransportLayerFactory(new DefaultPacketProcessor().wrap(packetProcessor))) .withAuthenticators(new StubAuthenticator.Factory()) .build() } diff --git a/src/test/groovy/com/hierynomus/smbj/connection/StubAuthenticator.groovy b/src/test/groovy/com/hierynomus/smbj/connection/StubAuthenticator.groovy deleted file mode 100644 index 43cc70da..00000000 --- a/src/test/groovy/com/hierynomus/smbj/connection/StubAuthenticator.groovy +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (C)2016 - SMBJ Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.hierynomus.smbj.connection - -import com.hierynomus.protocol.commons.ByteArrayUtils -import com.hierynomus.security.SecurityProvider -import com.hierynomus.smbj.SmbConfig -import com.hierynomus.smbj.auth.AuthenticateResponse -import com.hierynomus.smbj.auth.AuthenticationContext -import com.hierynomus.smbj.auth.Authenticator -import com.hierynomus.smbj.session.Session -import com.hierynomus.spnego.RawToken - -class StubAuthenticator implements Authenticator { - static class Factory implements com.hierynomus.protocol.commons.Factory.Named { - - @Override - String getName() { - return "stub" - } - - @Override - StubAuthenticator create() { - return new StubAuthenticator() - } - } - - @Override - void init(SmbConfig config) { - - } - - @Override - boolean supports(AuthenticationContext context) { - return true - } - - @Override - AuthenticateResponse authenticate(AuthenticationContext context, byte[] gssToken, ConnectionContext connectionContext) throws IOException { - def resp = new AuthenticateResponse(new RawToken(new byte[0])) - resp.sessionKey = ByteArrayUtils.parseHex("09921d4431b171b977370bf8910900f9") - return resp - } -} diff --git a/src/test/groovy/com/hierynomus/smbj/connection/StubTransportLayerFactory.groovy b/src/test/groovy/com/hierynomus/smbj/connection/StubTransportLayerFactory.groovy deleted file mode 100644 index 456f36e5..00000000 --- a/src/test/groovy/com/hierynomus/smbj/connection/StubTransportLayerFactory.groovy +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright (C)2016 - SMBJ Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.hierynomus.smbj.connection - -import com.hierynomus.mssmb2.SMB2PacketHeader -import com.hierynomus.mssmb2.SMB2MessageConverter -import com.hierynomus.mssmb2.SMB2Packet -import com.hierynomus.mssmb2.SMB2PacketData -import com.hierynomus.protocol.commons.buffer.Buffer -import com.hierynomus.protocol.transport.PacketHandlers -import com.hierynomus.protocol.transport.PacketReceiver -import com.hierynomus.protocol.transport.TransportException -import com.hierynomus.protocol.transport.TransportLayer -import com.hierynomus.smb.SMBPacket -import com.hierynomus.smbj.SmbConfig -import com.hierynomus.smbj.transport.TransportLayerFactory - -class StubTransportLayerFactory implements TransportLayerFactory { - private Closure processPacket - - StubTransportLayerFactory(Closure processPacket) { - this.processPacket = processPacket - } - - @Override - TransportLayer createTransportLayer(PacketHandlers handlers, SmbConfig config) { - return new StubTransportLayer(handlers.receiver, processPacket) - } - - private static class StubTransportLayer implements TransportLayer { - private boolean connected - private PacketReceiver receiver - private Closure processPacket - - StubTransportLayer(PacketReceiver receiver, Closure processPacket) { - this.receiver = receiver - if (this.receiver instanceof Connection) { - ((Connection) this.receiver).smb2Converter = new StubMessageConverter() - } - this.processPacket = processPacket - } - - @Override - void write(SMB2Packet packet) throws TransportException { - def response = processPacket.call(packet) - - if (response != null) { - response.header.messageId = packet.header.messageId - response.header.creditResponse = packet.header.creditRequest - receiver.handle(new StubPacketData(response)) - } else { - throw new TransportException("No response for " + packet) - } - } - - @Override - void connect(InetSocketAddress remoteAddress) throws IOException { - connected = true - } - - @Override - void disconnect() throws IOException { - connected = false - } - - @Override - boolean isConnected() { - return connected - } - } - - private static class StubPacketData extends SMB2PacketData { - private SMB2Packet packet - - StubPacketData(SMB2Packet packet) throws Buffer.BufferException { - super(new byte[0]) - this.packet = packet - } - - @Override - SMB2PacketHeader getHeader() { - return packet.header - } - - @Override - protected void readHeader() throws Buffer.BufferException { - } - } - - private static class StubMessageConverter extends SMB2MessageConverter { - private realConverter = new SMB2MessageConverter() - @Override - SMB2Packet readPacket(SMBPacket requestPacket, SMB2PacketData packetData) throws Buffer.BufferException { - if (packetData instanceof StubPacketData) { - return packetData.packet - } else { - return realConverter.readPacket(requestPacket, packetData) - } - } - } -} diff --git a/src/test/groovy/com/hierynomus/smbj/share/FileOutputStreamSpec.groovy b/src/test/groovy/com/hierynomus/smbj/share/FileOutputStreamSpec.groovy index 19dd3339..32633a24 100644 --- a/src/test/groovy/com/hierynomus/smbj/share/FileOutputStreamSpec.groovy +++ b/src/test/groovy/com/hierynomus/smbj/share/FileOutputStreamSpec.groovy @@ -26,10 +26,10 @@ import com.hierynomus.mssmb2.messages.SMB2WriteResponse import com.hierynomus.smbj.SMBClient import com.hierynomus.smbj.SmbConfig import com.hierynomus.smbj.auth.AuthenticationContext -import com.hierynomus.smbj.connection.BasicPacketProcessor import com.hierynomus.smbj.connection.Connection -import com.hierynomus.smbj.connection.StubAuthenticator -import com.hierynomus.smbj.connection.StubTransportLayerFactory +import com.hierynomus.smbj.testing.PacketProcessor.DefaultPacketProcessor +import com.hierynomus.smbj.testing.StubAuthenticator +import com.hierynomus.smbj.testing.StubTransportLayerFactory import spock.lang.Specification class FileOutputStreamSpec extends Specification { @@ -39,7 +39,7 @@ class FileOutputStreamSpec extends Specification { def setup() { devNull = new ByteArrayOutputStream() - def responder = new BasicPacketProcessor({ req -> + def responder = new DefaultPacketProcessor().wrap({ req -> if (req.packet instanceof SMB2CreateRequest) return createResponse() if (req.packet instanceof SMB2WriteRequest) @@ -51,7 +51,7 @@ class FileOutputStreamSpec extends Specification { def config = SmbConfig.builder() .withReadBufferSize(1024) .withDfsEnabled(false) - .withTransportLayerFactory(new StubTransportLayerFactory(responder.&processPacket)) + .withTransportLayerFactory(new StubTransportLayerFactory(responder)) .withAuthenticators(new StubAuthenticator.Factory()) .build() def client = new SMBClient(config) diff --git a/src/test/groovy/com/hierynomus/smbj/share/FileReadSpec.groovy b/src/test/groovy/com/hierynomus/smbj/share/FileReadSpec.groovy index cc449419..4bc54fd5 100644 --- a/src/test/groovy/com/hierynomus/smbj/share/FileReadSpec.groovy +++ b/src/test/groovy/com/hierynomus/smbj/share/FileReadSpec.groovy @@ -19,6 +19,7 @@ import com.hierynomus.msdtyp.AccessMask import com.hierynomus.mserref.NtStatus import com.hierynomus.msfscc.FileAttributes import com.hierynomus.mssmb2.* +import com.hierynomus.mssmb2.SMB2Packet import com.hierynomus.mssmb2.messages.SMB2CreateRequest import com.hierynomus.mssmb2.messages.SMB2CreateResponse import com.hierynomus.mssmb2.messages.SMB2ReadRequest @@ -27,10 +28,11 @@ import com.hierynomus.protocol.commons.ByteArrayUtils import com.hierynomus.smbj.SMBClient import com.hierynomus.smbj.SmbConfig import com.hierynomus.smbj.auth.AuthenticationContext -import com.hierynomus.smbj.connection.BasicPacketProcessor import com.hierynomus.smbj.connection.Connection -import com.hierynomus.smbj.connection.StubAuthenticator -import com.hierynomus.smbj.connection.StubTransportLayerFactory +import com.hierynomus.smbj.testing.PacketProcessor +import com.hierynomus.smbj.testing.PacketProcessor.DefaultPacketProcessor +import com.hierynomus.smbj.testing.StubAuthenticator +import com.hierynomus.smbj.testing.StubTransportLayerFactory import spock.lang.Specification import java.security.DigestOutputStream @@ -42,12 +44,11 @@ class FileReadSpec extends Specification { private MessageDigest digest private File file private Connection connection - private BasicPacketProcessor responder + private PacketProcessor responder def setup() { fileData = randomData(42, 12345) - - responder = new BasicPacketProcessor({ req -> + responder = new DefaultPacketProcessor().wrap({ req -> req = req.packet if (req instanceof SMB2CreateRequest) return createResponse() @@ -60,7 +61,7 @@ class FileReadSpec extends Specification { def config = SmbConfig.builder() .withReadBufferSize(1024) .withDfsEnabled(false) - .withTransportLayerFactory(new StubTransportLayerFactory(responder.&processPacket)) + .withTransportLayerFactory(new StubTransportLayerFactory(responder)) .withAuthenticators(new StubAuthenticator.Factory()) .build() def client = new SMBClient(config) @@ -136,12 +137,12 @@ class FileReadSpec extends Specification { def "should read entire file contents via input stream in IBM mode"() { when: - responder.addBehaviour { SMB2Packet req -> + responder = responder.wrap({ SMB2Packet req -> if (req instanceof SMB2ReadRequest) return read(req, fileData, true) null - } + }) def out = new DigestOutputStream(new ByteArrayOutputStream(), digest) def buffer = new byte[10] diff --git a/src/test/java/com/hierynomus/smbj/connection/ConnectionTest.java b/src/test/java/com/hierynomus/smbj/connection/ConnectionTest.java new file mode 100644 index 00000000..f1c10d35 --- /dev/null +++ b/src/test/java/com/hierynomus/smbj/connection/ConnectionTest.java @@ -0,0 +1,117 @@ +/* + * Copyright (C)2016 - SMBJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.smbj.connection; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +import java.util.UUID; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.junit.jupiter.api.Test; + +import com.hierynomus.msdtyp.FileTime; +import com.hierynomus.mserref.NtStatus; +import com.hierynomus.mssmb2.SMB2Dialect; +import com.hierynomus.mssmb2.SMB2MessageCommandCode; +import com.hierynomus.mssmb2.SMB2Packet; +import com.hierynomus.mssmb2.messages.SMB2NegotiateRequest; +import com.hierynomus.mssmb2.messages.SMB2NegotiateResponse; +import com.hierynomus.smbj.SMBClient; +import com.hierynomus.smbj.SmbConfig; +import com.hierynomus.smbj.testing.PacketProcessor; +import com.hierynomus.smbj.testing.StubAuthenticator; +import com.hierynomus.smbj.testing.StubTransportLayerFactory; +import com.hierynomus.smbj.testing.PacketProcessor.DefaultPacketProcessor; +import com.hierynomus.smbj.testing.PacketProcessor.NoOpPacketProcessor;; + +public class ConnectionTest { + + private static SmbConfig config(PacketProcessor processor) { + return SmbConfig.builder() + .withTransportLayerFactory(new StubTransportLayerFactory(new DefaultPacketProcessor().wrap(processor))) + .withAuthenticators(new StubAuthenticator.Factory()).build(); + } + + @Test + public void shouldUnregisterServerWhenConnectionClosed() throws Exception { + SmbConfig config = config(new NoOpPacketProcessor()); + SMBClient client = new SMBClient(config); + + Connection conn = client.connect("foo"); + assertNotNull(client.getServerList().lookup("foo")); + + conn.close(); + assertNull(client.getServerList().lookup("foo")); + } + + @Test + public void shouldNotUnregisterServerWhenNotAllConnectionsClosed() throws Exception { + SmbConfig config = config(new NoOpPacketProcessor()); + SMBClient client = new SMBClient(config); + + Connection conn = client.connect("foo"); + Connection conn2 = client.connect("foo"); + conn.close(); + + assertNotNull(client.getServerList().lookup("foo")); + + conn2.close(); + assertNull(client.getServerList().lookup("foo")); + } + + @Test + public void shouldConnectToServerWithChangedIdentificationWhenAllConnectionsClosed() throws Exception { + UUID one = UUID.fromString("ffeeddcc-bbaa-9988-7766-554433221100"); + UUID two = UUID.fromString("00112233-4455-6677-8899-aabbccddeeff"); + + final AtomicBoolean sent = new AtomicBoolean(false); + SmbConfig config = config((SMB2Packet req) -> { + req = req.getPacket(); + if (!sent.get() && req instanceof SMB2NegotiateRequest) { + sent.set(true); + SMB2NegotiateResponse resp = new SMB2NegotiateResponse(); + resp.getHeader().setMessageType(SMB2MessageCommandCode.SMB2_NEGOTIATE); + resp.getHeader().setStatusCode(NtStatus.STATUS_SUCCESS.getValue()); + resp.setDialect(SMB2Dialect.SMB_2_1); + resp.setSystemTime(FileTime.now()); + resp.setServerGuid(one); + return resp; + } + return null; + }); + + SMBClient client = new SMBClient(config); + + Connection conn = client.connect("foo"); + Connection conn2 = client.connect("foo"); + assertEquals(one, client.getServerList().lookup("foo").getServerGUID()); + assertEquals(one, conn.getConnectionContext().getServer().getServerGUID()); + assertEquals(one, conn2.getConnectionContext().getServer().getServerGUID()); + + conn.close(); + + assertNotNull(client.getServerList().lookup("foo")); + + conn2.close(); + assertNull(client.getServerList().lookup("foo")); + + conn = client.connect("foo"); + assertEquals(two, client.getServerList().lookup("foo").getServerGUID()); + assertEquals(two, conn.getConnectionContext().getServer().getServerGUID()); + } +} diff --git a/src/test/java/com/hierynomus/smbj/testing/PacketProcessor.java b/src/test/java/com/hierynomus/smbj/testing/PacketProcessor.java new file mode 100644 index 00000000..ffb7d399 --- /dev/null +++ b/src/test/java/com/hierynomus/smbj/testing/PacketProcessor.java @@ -0,0 +1,139 @@ +/* + * Copyright (C)2016 - SMBJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.smbj.testing; + +import java.util.EnumSet; +import java.util.UUID; + +import com.hierynomus.msdtyp.FileTime; +import com.hierynomus.mserref.NtStatus; +import com.hierynomus.mssmb2.SMB2Dialect; +import com.hierynomus.mssmb2.SMB2Error; +import com.hierynomus.mssmb2.SMB2MessageCommandCode; +import com.hierynomus.mssmb2.SMB2Packet; +import com.hierynomus.mssmb2.SMB2ShareCapabilities; +import com.hierynomus.mssmb2.SMB2ShareFlags; +import com.hierynomus.mssmb2.messages.SMB2Logoff; +import com.hierynomus.mssmb2.messages.SMB2NegotiateRequest; +import com.hierynomus.mssmb2.messages.SMB2NegotiateResponse; +import com.hierynomus.mssmb2.messages.SMB2SessionSetup; +import com.hierynomus.mssmb2.messages.SMB2TreeConnectResponse; +import com.hierynomus.mssmb2.messages.SMB2TreeDisconnect; + +@FunctionalInterface +public interface PacketProcessor { + SMB2Packet process(SMB2Packet request); + + default PacketProcessor wrap(PacketProcessor processor) { + PacketProcessor self = this; + return new PacketProcessor() { + @Override + public SMB2Packet process(SMB2Packet request) { + SMB2Packet p = processor.process(request); + if (p == null) { + return self.process(request); + } + return p; + } + }; + } + + public static class NoOpPacketProcessor implements PacketProcessor { + @Override + public SMB2Packet process(SMB2Packet request) { + return null; + } + } + + public static class DefaultPacketProcessor implements PacketProcessor { + @Override + public SMB2Packet process(SMB2Packet request) { + SMB2Packet resp = null; + request = request.getPacket(); // Ensure unwrapping + switch (request.getHeader().getMessage()) { + case SMB2_NEGOTIATE: + resp = negotiateResponse((SMB2NegotiateRequest) request); + break; + case SMB2_SESSION_SETUP: + resp = sessionSetupResponse((SMB2SessionSetup) request); + break; + case SMB2_TREE_CONNECT: + resp = connectResponse(); + break; + case SMB2_TREE_DISCONNECT: + resp = disconnectResponse(); + break; + case SMB2_LOGOFF: + resp = logoffResponse(); + break; + default: + resp = error(); + } + + return resp; + } + + private SMB2NegotiateResponse negotiateResponse(SMB2NegotiateRequest request) { + SMB2NegotiateResponse response = new SMB2NegotiateResponse(); + response.getHeader().setMessageType(SMB2MessageCommandCode.SMB2_NEGOTIATE); + response.getHeader().setStatusCode(NtStatus.STATUS_SUCCESS.getValue()); + response.setDialect(SMB2Dialect.SMB_2_1); + response.setSystemTime(FileTime.now()); + response.setServerGuid(UUID.fromString("00112233-4455-6677-8899-aabbccddeeff")); + return response; + } + + private SMB2SessionSetup sessionSetupResponse(SMB2SessionSetup request) { + SMB2SessionSetup response = new SMB2SessionSetup(); + response.getHeader().setMessageType(SMB2MessageCommandCode.SMB2_SESSION_SETUP); + response.getHeader().setStatusCode(NtStatus.STATUS_SUCCESS.getValue()); + response.getHeader().setSessionId(1); + response.setSecurityBuffer(new byte[16]); + response.setSessionFlags(EnumSet.noneOf(SMB2SessionSetup.SMB2SessionFlags.class)); + return response; + } + + private static SMB2Packet logoffResponse() { + SMB2Logoff response = new SMB2Logoff(); + response.getHeader().setMessageType(SMB2MessageCommandCode.SMB2_LOGOFF); + response.getHeader().setStatusCode(NtStatus.STATUS_SUCCESS.getValue()); + return response; + } + + private static SMB2Packet connectResponse() { + SMB2TreeConnectResponse response = new SMB2TreeConnectResponse(); + response.getHeader().setStatusCode(NtStatus.STATUS_SUCCESS.getValue()); + response.setCapabilities(EnumSet.of(SMB2ShareCapabilities.SMB2_SHARE_CAP_DFS)); + response.setShareFlags(EnumSet.noneOf(SMB2ShareFlags.class)); + response.setShareType((byte) 0x01); + return response; + } + + private static SMB2Packet disconnectResponse() { + SMB2TreeDisconnect response = new SMB2TreeDisconnect(); + response.getHeader().setStatusCode(NtStatus.STATUS_SUCCESS.getValue()); + return response; + } + + private static SMB2Packet error() { + SMB2Packet p = new SMB2Packet(); + p.getHeader().setStatusCode(NtStatus.STATUS_INTERNAL_ERROR.getValue()); + p.setError(new SMB2Error()); + return p; + } + + } +} diff --git a/src/test/java/com/hierynomus/smbj/testing/StubAuthenticator.java b/src/test/java/com/hierynomus/smbj/testing/StubAuthenticator.java new file mode 100644 index 00000000..e839d5e0 --- /dev/null +++ b/src/test/java/com/hierynomus/smbj/testing/StubAuthenticator.java @@ -0,0 +1,58 @@ +/* + * Copyright (C)2016 - SMBJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.smbj.testing; + +import java.io.IOException; + +import com.hierynomus.protocol.commons.ByteArrayUtils; +import com.hierynomus.smbj.SmbConfig; +import com.hierynomus.smbj.auth.AuthenticateResponse; +import com.hierynomus.smbj.auth.AuthenticationContext; +import com.hierynomus.smbj.auth.Authenticator; +import com.hierynomus.smbj.connection.ConnectionContext; +import com.hierynomus.spnego.RawToken; + +public class StubAuthenticator implements Authenticator { + public static class Factory implements com.hierynomus.protocol.commons.Factory.Named { + + @Override + public String getName() { + return "stub"; + } + + @Override + public StubAuthenticator create() { + return new StubAuthenticator(); + } + } + + @Override + public void init(SmbConfig config) { + } + + @Override + public boolean supports(AuthenticationContext context) { + return true; + } + + @Override + public AuthenticateResponse authenticate(AuthenticationContext context, byte[] gssToken, + ConnectionContext connectionContext) throws IOException { + AuthenticateResponse resp = new AuthenticateResponse(new RawToken(new byte[0])); + resp.setSessionKey(ByteArrayUtils.parseHex("09921d4431b171b977370bf8910900f9")); + return resp; + } +} diff --git a/src/test/java/com/hierynomus/smbj/testing/StubTransportLayerFactory.java b/src/test/java/com/hierynomus/smbj/testing/StubTransportLayerFactory.java new file mode 100644 index 00000000..be01dc8b --- /dev/null +++ b/src/test/java/com/hierynomus/smbj/testing/StubTransportLayerFactory.java @@ -0,0 +1,132 @@ +/* + * Copyright (C)2016 - SMBJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.smbj.testing; + +import java.io.IOException; +import java.net.InetSocketAddress; + +import com.hierynomus.mssmb2.SMB2MessageConverter; +import com.hierynomus.mssmb2.SMB2Packet; +import com.hierynomus.mssmb2.SMB2PacketData; +import com.hierynomus.mssmb2.SMB2PacketHeader; +import com.hierynomus.protocol.Packet; +import com.hierynomus.protocol.PacketData; +import com.hierynomus.protocol.commons.buffer.Buffer.BufferException; +import com.hierynomus.protocol.transport.PacketHandlers; +import com.hierynomus.protocol.transport.PacketReceiver; +import com.hierynomus.protocol.transport.TransportException; +import com.hierynomus.protocol.transport.TransportLayer; +import com.hierynomus.smb.SMBPacket; +import com.hierynomus.smbj.SmbConfig; +import com.hierynomus.smbj.connection.Connection; +import com.hierynomus.smbj.transport.TransportLayerFactory; + +public class StubTransportLayerFactory, P extends Packet> + implements TransportLayerFactory { + private PacketProcessor processor; + + public StubTransportLayerFactory(PacketProcessor p) { + this.processor = p; + } + + @Override + public TransportLayer

createTransportLayer(PacketHandlers handlers, + SmbConfig config) { + return new StubTransportLayer<>(handlers.getReceiver(), processor); + } + + private static class StubTransportLayer, P extends Packet> implements TransportLayer

{ + private boolean connected; + private PacketReceiver receiver; + private PacketProcessor processPacket; + + StubTransportLayer(PacketReceiver receiver, PacketProcessor processPacket) { + this.receiver = receiver; + if (this.receiver instanceof Connection) { + ((Connection) this.receiver).setMessageConverter(new StubMessageConverter()); + } + + this.processPacket = processPacket; + } + + @Override + public void write(P packet) throws TransportException { + if (!(packet instanceof SMB2Packet)) { + throw new TransportException("Unsupported packet type " + packet.getClass().getSimpleName()); + } + SMB2Packet request = (SMB2Packet) packet; + SMB2Packet response = processPacket.process(request); + + if (response != null) { + response.getHeader().setMessageId(request.getHeader().getMessageId()); + response.getHeader().setCreditResponse(request.getHeader().getCreditRequest()); + try { + receiver.handle((D) new StubPacketData(response)); + } catch (BufferException e) { + throw new TransportException(e); + } + } else { + throw new TransportException("No response for " + packet); + } + } + + @Override + public void connect(InetSocketAddress remoteAddress) throws IOException { + connected = true; + } + + @Override + public void disconnect() throws IOException { + connected = false; + } + + @Override + public boolean isConnected() { + return connected; + } + } + + private static class StubPacketData extends SMB2PacketData { + private SMB2Packet packet; + + StubPacketData(SMB2Packet packet) throws BufferException { + super(new byte[0]); + this.packet = packet; + } + + @Override + public SMB2PacketHeader getHeader() { + return packet.getHeader(); + } + + @Override + protected void readHeader() throws BufferException { + } + } + + private static class StubMessageConverter extends SMB2MessageConverter { + private SMB2MessageConverter realConverter = new SMB2MessageConverter(); + + @Override + public SMB2Packet readPacket(SMBPacket requestPacket, SMB2PacketData packetData) throws BufferException { + if (packetData instanceof StubPacketData) { + return ((StubPacketData) packetData).packet; + } else { + return realConverter.readPacket(requestPacket, packetData); + } + } + } +}