diff --git a/ssh_daemon/daemon.py b/ssh_daemon/daemon.py index d276690..68a37db 100644 --- a/ssh_daemon/daemon.py +++ b/ssh_daemon/daemon.py @@ -1,81 +1,152 @@ import paramiko -import select import socket import threading -import subprocess +import select +import sys import os -import pty -class CustomSSHServer(paramiko.ServerInterface): +class CustomServer(paramiko.ServerInterface): + def __init__(self, username): + self.username = username + self.event = threading.Event() + def check_auth_password(self, username, password): - # Implement your authentication logic here + # Implement proper authentication here return paramiko.AUTH_SUCCESSFUL + if username == "user0" and password == "password0": + return paramiko.AUTH_SUCCESSFUL + elif username == "user1" and password == "password1": + return paramiko.AUTH_SUCCESSFUL + elif username == "user2" and password == "password2": + return paramiko.AUTH_SUCCESSFUL + return paramiko.AUTH_FAILED def check_channel_request(self, kind, chanid): - return paramiko.OPEN_SUCCEEDED + if kind == "session": + return paramiko.OPEN_SUCCEEDED + return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED + + def check_channel_pty_request(self, channel, term, width, height, pixelwidth, pixelheight, modes): + return True + + def check_channel_shell_request(self, channel): + self.event.set() + return True + + def check_channel_exec_request(self, channel, command): + self.event.set() + return True def get_vm_port(username): - return 9002 + user_map = { + 'user0': 9000, + 'user1': 9001, + 'emiliko': 9002 + } + return user_map.get(username) -def handle_client(client_socket, addr): - transport = paramiko.Transport(client_socket) - # transport.add_server_key(paramiko.RSAKey(filename='server.key')) - transport.add_server_key(paramiko.Ed25519Key.from_private_key_file('daemon.key')) - server = CustomSSHServer() - transport.start_server(server=server) - - channel = transport.accept(20) - if channel is None: - print('*** No channel.') - return - - username = transport.get_username() - vm_port = get_vm_port(username) - master, slave = pty.openpty() - channel.get_pty() - channel.invoke_shell() - - ssh = subprocess.Popen( - ["ssh", "-p", "9002", "root@localhost"], - shell=True, - stdin=slave, - stdout=slave, - stderr=slave, - close_fds=True - ) - #ssh.wait() - - # Forward data between the channel and the process - while True: - r, w, e = select.select([channel, master], [], []) - if channel in r: - data = channel.recv(1024) - if not data: - break - os.write(master, data) - if master in r: - data = os.read(master, 1024) - if not data: - break - channel.send(data) - - channel.close() - ssh.terminate() +def handle_tcp_forwarding(channel, origin, destination): + try: + sock = socket.socket() + sock.connect(destination) + while True: + r, w, x = select.select([sock, channel], [], []) + if sock in r: + data = sock.recv(1024) + if len(data) == 0: + break + channel.send(data) + if channel in r: + data = channel.recv(1024) + if len(data) == 0: + break + sock.send(data) + sock.close() + channel.close() + except Exception as e: + print(f"Forwarding error: {str(e)}") -def main(): - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(('', 2222)) - server_socket.listen(100) +def handle_client(client_sock): + try: + transport = paramiko.Transport(client_sock) + transport.add_server_key(paramiko.Ed25519Key.from_private_key_file('daemon.key')) + + server = CustomServer(None) + transport.start_server(server=server) + + # Wait for authentication to complete + server.event.wait() + + username = transport.get_username() + if username is None: + print("Error: No username retrieved") + return + + vm_port = get_vm_port(username) + if vm_port is None: + print(f"Error: No VM associated with user {username}") + return + + channel = transport.accept(20) + if channel is None: + print("Error: Channel not established") + return + + vm_transport = paramiko.Transport(("localhost", vm_port)) + vm_transport.start_client() + vm_transport.auth_password('root', '') + + vm_channel = vm_transport.open_session() + vm_channel.get_pty() + vm_channel.invoke_shell() + + while True: + r, w, x = select.select([channel, vm_channel], [], []) + if channel in r: + data = channel.recv(1024) + if len(data) == 0: + break + vm_channel.send(data) + if vm_channel in r: + data = vm_channel.recv(1024) + if len(data) == 0: + break + channel.send(data) + + channel.close() + vm_channel.close() + + except paramiko.SSHException as e: + print(f"SSH error: {str(e)}") + finally: + try: + transport.close() + vm_transport.close() + except: + pass + +def start_server(port=22, bind_address=''): + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((bind_address, port)) + except Exception as e: + print(f"Error binding to port {port}: {str(e)}") + sys.exit(1) while True: - client_socket, addr = server_socket.accept() - threading.Thread(target=handle_client, args=(client_socket, addr)).start() - + try: + sock.listen(100) + print("Listening for connections...") + client, addr = sock.accept() + print(f"Got a connection from {addr}") + threading.Thread(target=handle_client, args=(client,)).start() + except Exception as e: + print(f"Error: {str(e)}") if __name__ == '__main__': - main() + start_server(port=2222)