diff --git a/src/main/java/com/hierynomus/sshj/transport/IdentificationStringParser.java b/src/main/java/com/hierynomus/sshj/transport/IdentificationStringParser.java new file mode 100644 index 000000000..46ff10982 --- /dev/null +++ b/src/main/java/com/hierynomus/sshj/transport/IdentificationStringParser.java @@ -0,0 +1,83 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.sshj.transport; + +import net.schmizz.sshj.common.Buffer; +import net.schmizz.sshj.common.ByteArrayUtils; +import net.schmizz.sshj.transport.TransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Arrays; + +public class IdentificationStringParser { + private static final Logger logger = LoggerFactory.getLogger(IdentificationStringParser.class); + private final Buffer.PlainBuffer buffer; + + private byte[] EXPECTED_START_BYTES = new byte[] {'S', 'S', 'H', '-'}; + + public IdentificationStringParser(Buffer.PlainBuffer buffer) { + this.buffer = buffer; + } + + public String parseIdentificationString() throws IOException { + for (;;) { + Buffer.PlainBuffer lineBuffer = new Buffer.PlainBuffer(); + int lineStartPos = buffer.rpos(); + for (;;) { + if (buffer.available() == 0) { + buffer.rpos(lineStartPos); + return ""; + } + byte b = buffer.readByte(); + lineBuffer.putByte(b); + if (b == '\n') { + if (checkForIdentification(lineBuffer)) { + return readIdentification(lineBuffer); + } + break; + } + } + } + } + + private String readIdentification(Buffer.PlainBuffer lineBuffer) throws Buffer.BufferException, TransportException { + byte[] bytes = new byte[lineBuffer.available()]; + lineBuffer.readRawBytes(bytes); + if (bytes.length > 255) { + logger.error("Incorrect identification String received, line was longer than expected: {}", new String(bytes)); + logger.error("Just for good measure, bytes were: {}", ByteArrayUtils.printHex(bytes, 0, bytes.length)); + throw new TransportException("Incorrect identification: line too long: " + ByteArrayUtils.printHex(bytes, 0, bytes.length)); + } + if (bytes[bytes.length - 2] != '\r') { + logger.error("Incorrect identification, was expecting a '\\r\\n' however got: '{}' (hex: {})", bytes[bytes.length - 2], Integer.toHexString(bytes[bytes.length - 2] & 0xFF)); + logger.error("Data received up til here was: {}", new String(bytes)); + throw new TransportException("Incorrect identification: bad line ending: " + ByteArrayUtils.toHex(bytes, 0, bytes.length)); + } + + // Strip off the \r\n + return new String(bytes, 0, bytes.length - 2); + } + + private boolean checkForIdentification(Buffer.PlainBuffer lineBuffer) throws Buffer.BufferException { + byte[] buf = new byte[4]; + lineBuffer.readRawBytes(buf); + // Reset + lineBuffer.rpos(0); + return Arrays.equals(EXPECTED_START_BYTES, buf); + } +} diff --git a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java index 4fef40819..fd3a51457 100644 --- a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java +++ b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java @@ -15,18 +15,14 @@ */ package net.schmizz.sshj.transport; +import com.hierynomus.sshj.transport.IdentificationStringParser; import net.schmizz.concurrent.ErrorDeliveryUtil; import net.schmizz.concurrent.Event; import net.schmizz.sshj.AbstractService; import net.schmizz.sshj.Config; import net.schmizz.sshj.SSHClient; import net.schmizz.sshj.Service; -import net.schmizz.sshj.common.Buffer; -import net.schmizz.sshj.common.DisconnectReason; -import net.schmizz.sshj.common.IOUtils; -import net.schmizz.sshj.common.Message; -import net.schmizz.sshj.common.SSHException; -import net.schmizz.sshj.common.SSHPacket; +import net.schmizz.sshj.common.*; import net.schmizz.sshj.transport.verification.AlgorithmsVerifier; import net.schmizz.sshj.transport.verification.HostKeyVerifier; import org.slf4j.Logger; @@ -207,38 +203,47 @@ private void sendClientIdent() throws IOException { */ private String readIdentification(Buffer.PlainBuffer buffer) throws IOException { - String ident; - - byte[] data = new byte[256]; - for (; ; ) { - int savedBufPos = buffer.rpos(); - int pos = 0; - boolean needLF = false; - for (; ; ) { - if (buffer.available() == 0) { - // Need more data, so undo reading and return null - buffer.rpos(savedBufPos); - return ""; - } - byte b = buffer.readByte(); - if (b == '\r') { - needLF = true; - continue; - } - if (b == '\n') - break; - if (needLF) - throw new TransportException("Incorrect identification: bad line ending"); - if (pos >= data.length) - throw new TransportException("Incorrect identification: line too long"); - data[pos++] = b; - } - ident = new String(data, 0, pos); - if (ident.startsWith("SSH-")) - break; - if (buffer.rpos() > 16 * 1024) - throw new TransportException("Incorrect identification: too many header lines"); + String ident = new IdentificationStringParser(buffer).parseIdentificationString(); + if (ident.isEmpty()) { + return ident; } +// +// byte[] data = new byte[256]; +// for (; ; ) { +// int savedBufPos = buffer.rpos(); +// int pos = 0; +// boolean needLF = false; +// for (; ; ) { +// if (buffer.available() == 0) { +// // Need more data, so undo reading and return null +// buffer.rpos(savedBufPos); +// return ""; +// } +// byte b = buffer.readByte(); +// if (b == '\r') { +// needLF = true; +// continue; +// } +// if (b == '\n') +// break; +// if (needLF) { +// log.error("Incorrect identification, was expecting a '\n' after the '\r', got: '{}' (hex: {})", b, Integer.toHexString(b & 0xFF)); +// log.error("Data received up til here was: {}", new String(data, 0, pos)); +// throw new TransportException("Incorrect identification: bad line ending: " + ByteArrayUtils.toHex(data, 0, pos)); +// } +// if (pos >= data.length) { +// log.error("Incorrect identification String received, line was longer than expected: {}", new String(data, 0, pos)); +// log.error("Just for good measure, bytes were: {}", ByteArrayUtils.printHex(data, 0, pos)); +// throw new TransportException("Incorrect identification: line too long: " + ByteArrayUtils.printHex(data, 0, pos)); +// } +// data[pos++] = b; +// } +// ident = new String(data, 0, pos); +// if (ident.startsWith("SSH-")) +// break; +// if (buffer.rpos() > 16 * 1024) +// throw new TransportException("Incorrect identification: too many header lines"); +// } if (!ident.startsWith("SSH-2.0-") && !ident.startsWith("SSH-1.99-")) throw new TransportException(DisconnectReason.PROTOCOL_VERSION_NOT_SUPPORTED, diff --git a/src/test/groovy/com/hierynomus/sshj/transport/IdentificationStringParserSpec.groovy b/src/test/groovy/com/hierynomus/sshj/transport/IdentificationStringParserSpec.groovy new file mode 100644 index 000000000..72a401be8 --- /dev/null +++ b/src/test/groovy/com/hierynomus/sshj/transport/IdentificationStringParserSpec.groovy @@ -0,0 +1,78 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.sshj.transport + +import net.schmizz.sshj.common.Buffer +import net.schmizz.sshj.transport.TransportException +import spock.lang.Specification + +class IdentificationStringParserSpec extends Specification { + + def "should parse simple identification string"() { + given: + def buffer = new Buffer.PlainBuffer() + buffer.putRawBytes("SSH-2.0-OpenSSH-6.13\r\n".bytes) + + when: + def ident = new IdentificationStringParser(buffer).parseIdentificationString() + + then: + ident == "SSH-2.0-OpenSSH-6.13" + } + + def "should not parse header lines as part of ident"() { + given: + def buffer = new Buffer.PlainBuffer() + buffer.putRawBytes("header1\nheader2\r\nSSH-2.0-OpenSSH-6.13\r\n".bytes) + + when: + def ident = new IdentificationStringParser(buffer).parseIdentificationString() + + then: + ident == "SSH-2.0-OpenSSH-6.13" + } + + def "should fail on too long ident string"() { + given: + def buffer = new Buffer.PlainBuffer() + buffer.putRawBytes("SSH-2.0-OpenSSH-6.13 ".bytes) + byte[] bs = new byte[255 - buffer.wpos()] + Arrays.fill(bs, 'a'.bytes[0]) + buffer.putRawBytes(bs).putRawBytes("\r\n".bytes) + + when: + new IdentificationStringParser(buffer).parseIdentificationString() + + then: + thrown(TransportException.class) + } + + def "should not fail on too long header line"() { + given: + def buffer = new Buffer.PlainBuffer() + buffer.putRawBytes("header1 ".bytes) + byte[] bs = new byte[255 - buffer.wpos()] + new Random().nextBytes(bs) + buffer.putRawBytes(bs).putRawBytes("\r\n".bytes) + buffer.putRawBytes("SSH-2.0-OpenSSH-6.13\r\n".bytes) + + when: + def ident = new IdentificationStringParser(buffer).parseIdentificationString() + + then: + ident == "SSH-2.0-OpenSSH-6.13" + } +}