/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sshd;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.HashSet;
import java.util.Set;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
import org.apache.mina.core.buffer.IoBuffer;
import org.apache.mina.core.service.IoAcceptor;
import org.apache.mina.core.service.IoHandlerAdapter;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
import org.apache.sshd.client.channel.ChannelDirectTcpip;
import org.apache.sshd.client.future.AuthFuture;
import org.apache.sshd.client.future.ConnectFuture;
import org.apache.sshd.common.SshdSocketAddress;
import org.apache.sshd.util.BaseTest;
import org.apache.sshd.util.BogusForwardingFilter;
import org.apache.sshd.util.BogusPasswordAuthenticator;
import org.apache.sshd.util.EchoShellFactory;
import org.apache.sshd.util.JSchLogger;
import org.apache.sshd.util.SimpleUserInfo;
import org.apache.sshd.util.Utils;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.LoggerFactory;
import static org.apache.sshd.util.Utils.getFreePort;
import static org.junit.Assert.assertEquals;
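/**
 * Port forwarding tests: local and remote forwarding through both the JSch client and the
 * native sshd client, plus direct-tcpip channels, run against an embedded SSH server and a
 * MINA-based echo server.
 */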
public class PortForwardingTest extends BaseTest {

    private final org.slf4j.Logger log = LoggerFactory.getLogger(getClass());

    /** Embedded SSH server under test. */
    private SshServer sshd;
    /** Port the embedded SSH server listens on. */
    private int sshPort;
    /** Port of the plain MINA echo server that forwarded traffic ends up at. */
    private int echoPort;
    /** Acceptor backing the echo server. */
    private IoAcceptor acceptor;
    /** Native client created by {@link #createNativeSession()}; null for JSch-only tests. */
    private SshClient client;

    /**
     * Starts the SSH server and a MINA echo server, each on a freshly chosen free port.
     */
    @Before
    public void setUp() throws Exception {
        sshPort = getFreePort();
        echoPort = getFreePort();

        sshd = SshServer.setUpDefaultServer();
        // Deliberately small window/packet sizes, presumably so that forwarded data has to
        // span many SSH packets and window adjustments.
        sshd.getProperties().put(SshServer.WINDOW_SIZE, "2048");
        sshd.getProperties().put(SshServer.MAX_PACKET_SIZE, "256");
        sshd.setPort(sshPort);
        sshd.setKeyPairProvider(Utils.createTestHostKeyProvider());
        sshd.setShellFactory(new EchoShellFactory());
        sshd.setPasswordAuthenticator(new BogusPasswordAuthenticator());
        sshd.setTcpipForwardingFilter(new BogusForwardingFilter());
        sshd.start();

        NioSocketAcceptor echoAcceptor = new NioSocketAcceptor();
        echoAcceptor.setHandler(new IoHandlerAdapter() {
            @Override
            public void messageReceived(IoSession session, Object message) throws Exception {
                // Copy the received bytes into a fresh buffer and write them straight back.
                IoBuffer recv = (IoBuffer) message;
                IoBuffer sent = IoBuffer.allocate(recv.remaining());
                sent.put(recv);
                sent.flip();
                session.write(sent);
            }
        });
        echoAcceptor.setReuseAddress(true);
        // Publish the acceptor before binding so tearDown() disposes it even if bind() fails.
        this.acceptor = echoAcceptor;
        echoAcceptor.bind(new InetSocketAddress(echoPort));
    }

    /**
     * Stops the SSH server, the echo server and (if one was created) the native client.
     * Each step runs even if an earlier one throws, so no component is leaked.
     */
    @After
    public void tearDown() throws Exception {
        try {
            if (sshd != null) {
                sshd.stop(true);
            }
        } finally {
            try {
                if (acceptor != null) {
                    acceptor.dispose(true);
                }
            } finally {
                if (client != null) {
                    client.stop();
                }
            }
        }
    }

    @Test
    public void testRemoteForwarding() throws Exception {
        Session session = createSession();
        try {
            int forwardedPort = getFreePort();
            session.setPortForwardingR(forwardedPort, "localhost", echoPort);
            // Presumably gives the server time to bind the remote listener before connecting.
            Thread.sleep(100);
            Socket s = new Socket("localhost", forwardedPort);
            try {
                assertEcho(s.getOutputStream(), s.getInputStream(), "Hello");
            } finally {
                s.close();
            }
            session.delPortForwardingR(forwardedPort);
        } finally {
            session.disconnect();
        }
    }

    @Test
    public void testRemoteForwardingNative() throws Exception {
        ClientSession session = createNativeSession();
        SshdSocketAddress remote = new SshdSocketAddress("", getFreePort());
        SshdSocketAddress local = new SshdSocketAddress("localhost", echoPort);
        session.startRemotePortForwarding(remote, local);

        Socket s = new Socket(remote.getHostName(), remote.getPort());
        try {
            assertEcho(s.getOutputStream(), s.getInputStream(), "Hello");
        } finally {
            s.close();
        }

        session.stopRemotePortForwarding(remote);
        session.close(false).await();
    }

    @Test
    public void testRemoteForwardingNativeBigPayload() throws Exception {
        ClientSession session = createNativeSession();
        SshdSocketAddress remote = new SshdSocketAddress("", getFreePort());
        SshdSocketAddress local = new SshdSocketAddress("localhost", echoPort);
        session.startRemotePortForwarding(remote, local);

        Socket s = new Socket(remote.getHostName(), remote.getPort());
        try {
            for (int i = 0; i < 1000; i++) {
                assertEcho(s.getOutputStream(), s.getInputStream(), "0123456789");
            }
        } finally {
            s.close();
        }

        session.stopRemotePortForwarding(remote);
        session.close(false).await();
    }

    @Test
    public void testRemoteForwardingNativeNoExplicitPort() throws Exception {
        ClientSession session = createNativeSession();
        // Port 0 lets the server pick the port; the actual address is returned as 'bound'.
        SshdSocketAddress remote = new SshdSocketAddress("0.0.0.0", 0);
        SshdSocketAddress local = new SshdSocketAddress("localhost", echoPort);
        SshdSocketAddress bound = session.startRemotePortForwarding(remote, local);

        Socket s = new Socket(bound.getHostName(), bound.getPort());
        try {
            assertEcho(s.getOutputStream(), s.getInputStream(), "Hello");
        } finally {
            s.close();
        }

        session.stopRemotePortForwarding(bound);
        session.close(false).await();
    }

    @Test
    public void testLocalForwarding() throws Exception {
        Session session = createSession();
        try {
            int forwardedPort = getFreePort();
            session.setPortForwardingL(forwardedPort, "localhost", echoPort);
            Socket s = new Socket("localhost", forwardedPort);
            try {
                assertEcho(s.getOutputStream(), s.getInputStream(), "Hello");
            } finally {
                s.close();
            }
            session.delPortForwardingL(forwardedPort);
        } finally {
            session.disconnect();
        }
    }

    @Test
    public void testLocalForwardingNative() throws Exception {
        ClientSession session = createNativeSession();
        SshdSocketAddress local = new SshdSocketAddress("", getFreePort());
        SshdSocketAddress remote = new SshdSocketAddress("localhost", echoPort);
        SshdSocketAddress bound = session.startLocalPortForwarding(local, remote);

        Socket s = new Socket(bound.getHostName(), bound.getPort());
        try {
            assertEcho(s.getOutputStream(), s.getInputStream(), "Hello");
        } finally {
            s.close();
        }

        session.stopLocalPortForwarding(bound);
        session.close(false).await();
    }

    @Test
    public void testLocalForwardingNativeBigPayload() throws Exception {
        ClientSession session = createNativeSession();
        SshdSocketAddress local = new SshdSocketAddress("", getFreePort());
        SshdSocketAddress remote = new SshdSocketAddress("localhost", echoPort);
        SshdSocketAddress bound = session.startLocalPortForwarding(local, remote);

        Socket s = new Socket(bound.getHostName(), bound.getPort());
        try {
            for (int i = 0; i < 1000; i++) {
                assertEcho(s.getOutputStream(), s.getInputStream(), "Hello");
            }
        } finally {
            s.close();
        }

        session.stopLocalPortForwarding(bound);
        session.close(false).await();
    }

    @Test
    public void testForwardingChannel() throws Exception {
        ClientSession session = createNativeSession();
        SshdSocketAddress local = new SshdSocketAddress("", getFreePort());
        SshdSocketAddress remote = new SshdSocketAddress("localhost", echoPort);

        ChannelDirectTcpip channel = session.createDirectTcpipChannel(local, remote);
        channel.open().await();
        // getInvertedIn() is the channel's outbound stream, getInvertedOut() its inbound one.
        assertEcho(channel.getInvertedIn(), channel.getInvertedOut(), "Hello");
        channel.close(false);

        session.close(false).await();
    }

    /**
     * Regression test: after the JSch client's socket is closed abruptly while a forwarded
     * connection is open, no "NioProcessor-" thread may remain stuck inside
     * {@code TcpipForwardSupport}. A stuck thread makes the polling loop spin until the
     * 20s timeout fails the test.
     */
    @Test(timeout = 20000)
    public void testRemoteForwardingWithDisconnect() throws Exception {
        Session session = createSession();

        // 1. Create a port forward.
        int forwardedPort = getFreePort();
        session.setPortForwardingR(forwardedPort, "localhost", echoPort);

        // 2. Establish a connection through it; kept open until the checks are done.
        Socket s = new Socket("localhost", forwardedPort);
        try {
            // 3. Simulate the client going away.
            rudelyDisconnectJschSession(session);

            // 4. Make sure no NIO processor thread is stuck in the forwarding code.
            Thread.sleep(1000);
            ThreadGroup root = Thread.currentThread().getThreadGroup();
            while (root.getParent() != null) {
                root = root.getParent();
            }
            while (!findThreads(root, "NioProcessor-").isEmpty()) {
                // InterruptedException is deliberately not caught: JUnit interrupts the test
                // thread on timeout, and swallowing it would leave this loop running forever.
                Thread.sleep(1000);
            }
        } finally {
            s.close();
        }

        session.delPortForwardingR(forwardedPort);
        session.disconnect();
    }

    /**
     * Writes {@code message} to {@code out}, reads the same number of bytes back from
     * {@code in} and asserts they are equal. The read loops because a single
     * {@code InputStream.read} may legally return only part of the echoed data.
     *
     * @param out stream the message is written to (and flushed)
     * @param in stream the echoed message is expected on
     * @param message text to send; encoded and decoded with the platform charset
     * @throws IOException if writing or reading fails
     */
    private static void assertEcho(OutputStream out, InputStream in, String message)
            throws IOException {
        byte[] sent = message.getBytes();
        out.write(sent);
        out.flush();

        byte[] received = new byte[sent.length];
        int off = 0;
        while (off < received.length) {
            int n = in.read(received, off, received.length - off);
            if (n < 0) {
                Assert.fail("stream closed after " + off + " of " + received.length
                        + " echoed bytes");
            }
            off += n;
        }
        assertEquals(message, new String(received));
    }

    /**
     * Close the socket inside this JSch session. Uses reflection to find its private
     * {@code socket} field and closes it directly, bypassing a clean SSH disconnect.
     *
     * @param session the Session to violate
     * @throws Exception if the field cannot be accessed or a socket-state assertion fails
     */
    private void rudelyDisconnectJschSession(Session session) throws Exception {
        Field fSocket = session.getClass().getDeclaredField("socket");
        fSocket.setAccessible(true);
        Socket socket = (Socket) fSocket.get(session);

        Assert.assertTrue("socket is not connected", socket.isConnected());
        Assert.assertFalse("socket should not be closed", socket.isClosed());
        socket.close();
        Assert.assertTrue("socket has not closed", socket.isClosed());
    }

    /**
     * Recursively collects the threads in {@code group} and all of its subgroups whose
     * name contains {@code name} and which are currently inside {@code TcpipForwardSupport}
     * (see {@link #checkThreadForPortForward}).
     *
     * @param group thread group to search
     * @param name substring a thread's name must contain
     * @return matching threads; empty if none
     */
    private Set<Thread> findThreads(ThreadGroup group, String name) {
        Set<Thread> ret = new HashSet<Thread>();

        // Over-allocate: activeCount() is only an estimate and may grow before enumerate().
        Thread[] threads = new Thread[group.activeCount() * 2];
        int numThreads = group.enumerate(threads, false);
        for (int i = 0; i < numThreads; ++i) {
            if (checkThreadForPortForward(threads[i], name)) {
                ret.add(threads[i]);
            }
        }

        // Recurse into direct subgroups (enumerate(..., false) does not descend by itself).
        ThreadGroup[] groups = new ThreadGroup[group.activeGroupCount() * 2];
        int numGroups = group.enumerate(groups, false);
        for (int i = 0; i < numGroups; ++i) {
            ret.addAll(findThreads(groups[i], name));
        }
        return ret;
    }

    /**
     * Tells whether {@code thread} has a name containing {@code name} and is currently
     * executing {@code TcpipForwardSupport.close} or {@code TcpipForwardSupport.sessionCreated}.
     * A match is logged as a warning.
     *
     * @param thread thread to inspect; may be null (an unfilled enumerate slot)
     * @param name substring the thread name must contain
     * @return true if the thread matches and is inside one of those methods
     */
    private boolean checkThreadForPortForward(Thread thread, String name) {
        if (thread == null || !thread.getName().contains(name)) {
            return false;
        }
        for (StackTraceElement frame : thread.getStackTrace()) {
            String clazzName = frame.getClassName();
            String methodName = frame.getMethodName();
            if (clazzName.equals("org.apache.sshd.server.session.TcpipForwardSupport")
                    && (methodName.equals("close") || methodName.equals("sessionCreated"))) {
                log.warn(thread.getName() + " stuck at " + clazzName + "." + methodName + ": "
                        + frame.getLineNumber());
                return true;
            }
        }
        return false;
    }

    /**
     * Opens a JSch session to the embedded server as user "sshd" with password "sshd".
     *
     * @return a connected JSch session
     * @throws JSchException if connecting or authenticating fails
     */
    protected Session createSession() throws JSchException {
        JSchLogger.init();
        JSch sch = new JSch();
        Session session = sch.getSession("sshd", "localhost", sshPort);
        session.setUserInfo(new SimpleUserInfo("sshd"));
        session.connect();
        return session;
    }

    /**
     * Starts a native sshd client (stored in {@link #client} so tearDown() stops it) with
     * the same window/packet sizes as the server, connects and password-authenticates as
     * "sshd".
     *
     * @return an authenticated native client session
     * @throws Exception if connecting or authenticating fails
     */
    protected ClientSession createNativeSession() throws Exception {
        client = SshClient.setUpDefaultClient();
        client.getProperties().put(SshServer.WINDOW_SIZE, "2048");
        client.getProperties().put(SshServer.MAX_PACKET_SIZE, "256");
        client.setTcpipForwardingFilter(new BogusForwardingFilter());
        client.start();

        ClientSession session = client.connect("sshd", "localhost", sshPort).await().getSession();
        session.addPasswordIdentity("sshd");
        session.auth().verify();
        return session;
    }
}