#!/usr/bin/env python3
import argparse
import dataclasses
import json
import os
import os.path
import pathlib
import shlex
import socket
import stat
import subprocess
import sys
import time
from types import TracebackType

import psutil


original_excepthook = sys.excepthook

# ignore traceback unless we set debug
def exception_handler(exception_type: type[BaseException],
                      exception: BaseException,
                      traceback: TracebackType | None) -> None:
  if args and args.debug:
    original_excepthook(exception_type, exception, traceback)
  else:
    print(f"{exception_type.__name__}: {exception}")

sys.excepthook = exception_handler


def boolean(arg: str) -> bool:
  """Convert a case-insensitive 'true'/'false' string into a bool.

  Used as an argparse ``type=`` converter.

  Args:
    arg: Text to convert.

  Returns:
    True for 'true', False for 'false' (any letter case).

  Raises:
    ValueError: if the lower-cased text is neither 'true' nor 'false'.
  """
  lowered = arg.lower()
  known = {'true': True, 'false': False}
  if lowered in known:
    return known[lowered]
  raise ValueError(f"{lowered!r} is not 'true' or 'false'.")


# Top-level command line definition. Options are registered in the order they
# should appear in --help.
parser = argparse.ArgumentParser(
  description='A launcher/control script for wprsc.',
  prog='wprs',
)

# Keyword arguments shared by every 'true'/'false' option: the text is
# converted by boolean() and the result must be one of the two bool values.
_BOOLEAN_OPTION = {'type': boolean, 'choices': [True, False]}

parser.add_argument('--pulseaudio-forwarding', default='true', **_BOOLEAN_OPTION)
parser.add_argument('--wprsc-path', default='wprsc')
parser.add_argument('--wprsc-wayland-debug', default='false', **_BOOLEAN_OPTION)

# These three take a single shell-style string that shlex.split turns into an
# argument list (the '' default therefore becomes an empty list).
parser.add_argument('--wprsc-args', default='', type=shlex.split)
parser.add_argument('--additional-ssh-tunnel-args', default='', type=shlex.split)
parser.add_argument('--additional-ssh-command-args', default='', type=shlex.split)

# Kept as a plain string; it is whitespace-split where the remote command line
# is assembled.
parser.add_argument('--additional-command-env-vars', default='')
parser.add_argument('--command-wayland-debug', default='false', **_BOOLEAN_OPTION)
parser.add_argument(
  '--wprsd-wayland-display',
  help='The WAYLAND_DISPLAY wprsd is listening on.',
  default='wprs-0')
parser.add_argument('--xwayland', default='true', **_BOOLEAN_OPTION)
parser.add_argument(
  '--wprsd-xwayland-display',
  help='The DISPLAY wprsd is listening on.',
  default=':100')
# Stored as args.debug, which the custom excepthook consults.
parser.add_argument(
  '--print-stacktrace',
  dest='debug',
  default='false',
  help='Adds python stacktrace context to errors',
  **_BOOLEAN_OPTION)

# The three ways of choosing a window-title prefix are mutually exclusive:
# argparse rejects a command line that passes more than one of them.
title_prefix_group = parser.add_mutually_exclusive_group()
title_prefix_group.add_argument(
  '--title-prefix',
  default='',
  type=str,
  help='Prefix window titles with a string.')
title_prefix_group.add_argument(
  '--title-prefix-hostname',
  default=False,
  type=boolean,
  help='Prefix window titles with the remote hostname.')
title_prefix_group.add_argument(
  '--title-prefix-fqdn',
  default=False,
  type=boolean,
  help='prefix window titles with the remote fqdn.')

# Positional SSH destination; it is passed verbatim to every ssh invocation.
parser.add_argument('destination')

# The chosen subcommand name is stored in args.subcommands; each subparser
# later gets a `func` default that implements it.
subparsers = parser.add_subparsers(title='subcommands', dest='subcommands')

parser_attach = subparsers.add_parser(
  'attach',
  help='Start wprsc and connect to a remote wprsd session.')
parser_detach = subparsers.add_parser(
  'detach',
  help='Stop wprsc. The remote wprsd session will persist.')

# `run` takes the remote program followed by any number of its arguments.
parser_run = subparsers.add_parser(
  'run',
  help='Run a remote application under wprs, attaching if necessary.')
parser_run.add_argument('remote_command')
parser_run.add_argument('argument', nargs='*')

parser_restart_wprsd = subparsers.add_parser(
  'restart-wprsd',
  help=('Restart the remote wprsd, useful if it is stuck or in a bad state. '
        'This will terminate any remote applications running against wprsd.'))


def xdg_runtime_dir() -> str | None:
  return os.getenv('XDG_RUNTIME_DIR')


def socket_dir() -> str:
  """Return the local directory used for sockets and the pid file.

  The first non-empty value of XDG_RUNTIME_DIR and TMPDIR is used; when both
  are unset or empty the result is '/tmp'.
  """
  for candidate in (xdg_runtime_dir(), os.getenv('TMPDIR')):
    if candidate:
      return candidate
  return '/tmp'


# Local paths of the wprs data socket and the wprsc control socket. Both live
# in socket_dir(); they are the local ends of the `ssh -L` forwards set up by
# forward_wprs_sock() and are removed again by stop_ssh_tunnel().
WPRS_SOCKET = os.path.join(socket_dir(), 'wprs.sock')
WPRS_CONTROL_SOCKET = os.path.join(socket_dir(), 'wprsc-ctrl.sock')


# Argument prefix shared by every ssh invocation in this script. The
# ControlMaster/ControlPath/ControlPersist options make the invocations reuse
# one persistent connection whose control socket lives under
# <socket_dir()>/ssh/ (start_ssh_tunnel() creates that directory).
# NOTE(review): `%C` is left for ssh itself to expand (an ssh_config token,
# presumably a per-connection hash) — confirm against ssh_config(5).
SSH_COMMON_ARGS = [
  'ssh',
  '-o', 'ControlMaster=auto',
  '-o', f'ControlPath={socket_dir()}/ssh/wprs-%C',
  '-o', 'ControlPersist=yes',
]


def start_ssh_tunnel() -> None:
  """Start the background SSH connection to args.destination.

  Makes sure the control-socket directory exists and is accessible only by the
  current user, then runs ssh with ``-f -N -T`` plus any
  ``--additional-ssh-tunnel-args`` in a new session.

  Raises:
    subprocess.CalledProcessError: if ssh exits with a non-zero status.
  """
  control_dir = os.path.join(socket_dir(), 'ssh')
  os.makedirs(control_dir, exist_ok=True)
  # Owner-only read/write/execute (0o700).
  os.chmod(control_dir, stat.S_IRWXU)

  cmd = [
    *SSH_COMMON_ARGS,
    '-f', '-N', '-T',
    *args.additional_ssh_tunnel_args,
    args.destination,
  ]
  print(f'Starting SSH tunnel: {cmd!r}')
  subprocess.run(cmd, env=os.environ, check=True, start_new_session=True)


def stop_ssh_tunnel() -> None:
  """Shut down the shared SSH connection and remove the local socket files.

  Asks the ssh control master to exit (``-O exit``); a failure of that command
  is ignored. Afterwards both forwarded socket paths are removed if present.
  """
  cmd = (SSH_COMMON_ARGS +
         ['-O', 'exit'] +
         [args.destination])
  print(f'Stopping SSH tunnel: {cmd!r}')
  subprocess.run(cmd, env=os.environ)

  # Remove each socket independently: a missing wprs.sock must not prevent
  # the control socket from being cleaned up.
  for sock_path in (WPRS_SOCKET, WPRS_CONTROL_SOCKET):
    pathlib.Path(sock_path).unlink(missing_ok=True)


def check_ssh_tunnel() -> bool:
  """Return True when `ssh -O check` for args.destination exits with status 0."""
  cmd = [*SSH_COMMON_ARGS, '-O', 'check', args.destination]
  print(f'Checking SSH tunnel: {cmd!r}')
  result = subprocess.run(cmd, env=os.environ)
  return result.returncode == 0


def remote_env_var(var: str) -> str:
  """Return the value of environment variable `var` on the remote host.

  Runs ``sh -c "echo $<var>"`` remotely and returns its stripped stdout, which
  is an empty string when the variable produces no output.

  Args:
    var: Variable name, without the leading '$'.
  """
  # The double quotes are part of the argument sent over ssh.
  # NOTE(review): presumably they keep `echo $VAR` together as the single
  # argument of `sh -c` once the remote shell re-parses the command line.
  echo_snippet = f'"echo ${var}"'
  return run_remote_command_with_stdout(['sh', '-c', echo_snippet])

def run_remote_command_with_stdout(remote_cmd: list[str]) -> str:
  """Run `remote_cmd` on args.destination over ssh and return its stdout.

  Args:
    remote_cmd: Remote program and arguments, appended after the destination.

  Returns:
    The captured standard output with surrounding whitespace stripped.

  Raises:
    subprocess.CalledProcessError: if ssh exits with a non-zero status.
  """
  cmd = [*SSH_COMMON_ARGS, args.destination, *remote_cmd]
  print(f'Running remote command with stdout: {remote_cmd!r}')
  completed = subprocess.run(
    cmd,
    stdout=subprocess.PIPE,
    text=True,
    check=True,
    env=os.environ)
  return completed.stdout.strip()

def remote_socket_dir() -> str:
  """Return the socket directory on the remote host.

  Mirrors the local socket_dir(): the remote XDG_RUNTIME_DIR if it is set,
  otherwise the remote TMPDIR, otherwise '/tmp'. Each variable that has to be
  inspected costs one remote command.

  Raises:
    subprocess.CalledProcessError: if a remote lookup fails.
  """
  # TMPDIR (not TEMPDIR) is the variable the local socket_dir() consults.
  return remote_env_var('XDG_RUNTIME_DIR') or remote_env_var('TMPDIR') or '/tmp'


def forward_wprs_sock() -> None:
  """Forward the wprs data and control sockets over the shared SSH connection.

  Uses ``ssh -O forward`` with two ``-L`` forwards, mapping WPRS_SOCKET and
  WPRS_CONTROL_SOCKET to ``wprs.sock`` and ``wprsc-ctrl.sock`` in the remote
  socket directory.

  Raises:
    subprocess.CalledProcessError: if a remote lookup or the forward request
      fails.
  """
  # Resolve the remote directory once; every lookup runs remote commands.
  remote_dir = remote_socket_dir()
  cmd = (SSH_COMMON_ARGS +
         ['-O', 'forward',
          '-L', f'{WPRS_SOCKET}:{remote_dir}/wprs.sock',
          '-L', f'{WPRS_CONTROL_SOCKET}:{remote_dir}/wprsc-ctrl.sock']
         + [args.destination])
  print(f'Forwarding wprs sockets: {cmd!r}')
  subprocess.run(cmd, env=os.environ, check=True)


def pulse_socket() -> str:
  """Return the path of the local PulseAudio unix socket.

  If PULSE_SERVER is set it must have the form ``unix:<path>`` and ``<path>``
  is returned. Otherwise the path is ``<XDG_RUNTIME_DIR>/pulse/native``.

  Raises:
    RuntimeError: if PULSE_SERVER is set but does not start with 'unix:', or
      if neither PULSE_SERVER nor XDG_RUNTIME_DIR is set.
  """
  server = os.getenv('PULSE_SERVER')

  if server is None:
    runtime_dir = xdg_runtime_dir()
    if runtime_dir is None:
      raise RuntimeError('PULSE_SERVER and XDG_RUNTIME_DIR are both unset.')
    return f'{runtime_dir}/pulse/native'

  if not server.startswith('unix:'):
    raise RuntimeError(f'PulseAudio server path {server} not supported.')
  return server.removeprefix('unix:')


def forward_pulse_sock() -> None:
  """Expose the local PulseAudio socket on the remote host.

  Requests a reverse (``-R``) forward over the shared SSH connection from
  ``wprs-pulse`` in the remote socket directory to the local pulse_socket().

  Raises:
    RuntimeError: if the local PulseAudio socket cannot be determined.
    subprocess.CalledProcessError: if a remote lookup or the forward request
      fails.
  """
  reverse_spec = f'{remote_socket_dir()}/wprs-pulse:{pulse_socket()}'
  cmd = [*SSH_COMMON_ARGS, '-O', 'forward', '-R', reverse_spec, args.destination]
  print(f'Forwarding pulseaudio socket: {cmd!r}')
  subprocess.run(cmd, env=os.environ, check=True)


def create_ssh_auth_sock_symlink() -> None:
  """Symlink the remote $SSH_AUTH_SOCK to a fixed name in the remote socket dir.

  The link is ``wprs-ssh-auth.sock``; start_remote_command() points the remote
  command's SSH_AUTH_SOCK at that path. ``ln -sf`` replaces an existing link.

  Raises:
    subprocess.CalledProcessError: if a remote lookup or the ln command fails.
  """
  # $SSH_AUTH_SOCK is deliberately left unexpanded here so that it is resolved
  # on the remote side, not locally.
  link_snippet = f'"ln -sf $SSH_AUTH_SOCK {remote_socket_dir()}/wprs-ssh-auth.sock"'
  cmd = [*SSH_COMMON_ARGS, args.destination, 'sh', '-c', link_snippet]
  print(f'Creating SSH_AUTH_SOCK symlink: {cmd!r}')
  subprocess.run(cmd, env=os.environ, check=True)


# File in socket_dir() holding the PID of the wprsc process: written by
# start_wprsc(), read by wprsc_pid() and removed by stop_wprsc().
WPRS_PIDFILE = os.path.join(socket_dir(), 'wprsc.pid')


def wprsc_pid() -> int | None:
  """Return the PID recorded in WPRS_PIDFILE.

  Returns:
    The PID, or None when the file does not exist or does not contain an
    integer.
  """
  try:
    pid = int(pathlib.Path(WPRS_PIDFILE).read_text().strip())
  except (FileNotFoundError, ValueError):
    return None
  return pid


def wprsc_proc() -> psutil.Process | None:
  """Return a psutil handle for the process named in the pid file.

  Returns:
    The process, or None when no PID is recorded or no process with that PID
    exists.
  """
  if (pid := wprsc_pid()) is None:
    return None

  try:
    process = psutil.Process(pid)
  except psutil.NoSuchProcess:
    return None
  return process


def wprsc_env(wayland_debug: bool) -> dict[str, str]:
  """Return the environment variables this script sets for wprsc.

  Args:
    wayland_debug: Whether WAYLAND_DEBUG should be '1' rather than '0'.

  Returns:
    A dict with WAYLAND_DEBUG and RUST_BACKTRACE (always '1').
  """
  env = {}
  env['WAYLAND_DEBUG'] = f'{int(wayland_debug)}'
  env['RUST_BACKTRACE'] = '1'
  return env


def start_wprsc(cmd: list[str], wayland_debug: bool) -> None:
  """Launch wprsc in a new session and record its PID in WPRS_PIDFILE.

  Args:
    cmd: Full wprsc command line.
    wayland_debug: Forwarded to wprsc_env() to choose WAYLAND_DEBUG.
  """
  # The wprsc-specific variables override same-named inherited ones.
  child_env = {**os.environ, **wprsc_env(wayland_debug)}
  print(f'Executing wprsc: {cmd!r}')
  child = subprocess.Popen(cmd, env=child_env, start_new_session=True)
  pathlib.Path(WPRS_PIDFILE).write_text(str(child.pid))


def stop_wprsc() -> None:
  """Terminate the recorded wprsc process, if any, and delete the pid file."""
  running = wprsc_proc()
  if running is not None:
    print(f'Stopping wprsc ({running.pid})')
    running.terminate()

  # The pid file may already be gone; that is not an error.
  pathlib.Path(WPRS_PIDFILE).unlink(missing_ok=True)


@dataclasses.dataclass
class Response:
  """One reply read from the control socket."""
  # 'Ok' marks success; any other value is treated as an error.
  status: str
  # Result text on success, error description otherwise.
  payload: str

  @classmethod
  def from_json(cls, s: str) -> 'Response':
    """Build a Response from a JSON object whose keys are the field names."""
    decoded = json.loads(s)
    return cls(**decoded)

  def is_ok(self) -> bool:
    """Return True when the status is exactly 'Ok'."""
    return self.status == 'Ok'

  def payload_if_ok(self) -> str:
    """Return the payload of a successful response.

    Raises:
      RuntimeError: if the status is not 'Ok'; the message carries the payload.
    """
    if not self.is_ok():
      raise RuntimeError(f'Response was Err: {self.payload}')
    return self.payload


@dataclasses.dataclass
class Capabilities:
  """Capabilities carried in the payload of a 'caps' response."""
  # Consulted by start_remote_command() before DISPLAY is exported.
  xwayland: bool

  @classmethod
  def from_json(cls, s: str) -> 'Capabilities | None':
    """Parse a JSON payload; a JSON ``null`` yields None."""
    decoded = json.loads(s)
    if decoded is None:
      return None
    return cls(**decoded)


def query_capabilities() -> Capabilities | None:
  """Ask wprsc for its capabilities over the control socket.

  Connects to WPRS_CONTROL_SOCKET, sends the line ``caps`` and parses the
  single-line JSON reply. Because wprsc may still be starting, a failed
  connection is retried once per second, for at most 10 retries.

  Returns:
    The parsed capabilities, or None if the reply payload is JSON ``null``.

  Raises:
    ConnectionError, FileNotFoundError: if the socket is still unusable after
      the last retry.
    RuntimeError: if the reply status is not 'Ok'.
  """
  max_retries = 10
  retries = 0
  while True:
    try:
      with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
        s.connect(WPRS_CONTROL_SOCKET)
        with s.makefile('rw') as f:
          f.write('caps\n')
          f.flush()
          resp = Response.from_json(f.readline().strip())
          return Capabilities.from_json(resp.payload_if_ok())
    except (ConnectionError, FileNotFoundError):
      # FileNotFoundError means the socket file does not exist yet, which is
      # exactly the "not started yet" case this loop waits for; it is not a
      # ConnectionError subclass, so it has to be listed explicitly.
      if retries == max_retries:
        raise
      retries += 1
      time.sleep(1)

def get_title_prefix() -> list[str]:
  """Build the optional ``--title-prefix=...`` argument for wprsc.

  The prefix text is taken from ``--title-prefix`` if given, else from the
  remote short hostname (``hostname -s``) when ``--title-prefix-hostname`` is
  true, else from the remote fully qualified name (``hostname -f``) when
  ``--title-prefix-fqdn`` is true. ': ' is appended to the chosen text.

  Returns:
    A one-element argument list, or an empty list when no prefix was requested.

  Raises:
    subprocess.CalledProcessError: if the remote hostname lookup fails.
  """
  if args.title_prefix:
    prefix = args.title_prefix
  elif args.title_prefix_hostname:
    # -s: short host name, as promised by the option's help text.
    prefix = run_remote_command_with_stdout(['hostname', '-s'])
  elif args.title_prefix_fqdn:
    # -f: fully qualified domain name.
    prefix = run_remote_command_with_stdout(['hostname', '-f'])
  else:
    return []

  return [f'--title-prefix={prefix}: ']

def maybe_start_wprsc() -> Capabilities | None:
  """Make sure a wprsc with the requested configuration is running.

  wprsc is (re)started when no recorded process exists, when the running
  process has a different command line, or when any variable from wprsc_env()
  differs in its environment. Afterwards its capabilities are queried.

  Returns:
    The result of query_capabilities().
  """
  wanted_cmd = [args.wprsc_path, *args.wprsc_args, *get_title_prefix()]
  running = wprsc_proc()

  if running is None:
    needs_start = True
  else:
    cmdline_differs = running.cmdline() != wanted_cmd
    running_env = running.environ()
    env_differs = any(
      running_env.get(name) != value
      for name, value in wprsc_env(args.wprsc_wayland_debug).items())
    needs_start = cmdline_differs or env_differs

  if needs_start:
    # stop_wprsc() is harmless when nothing is running.
    stop_wprsc()
    start_wprsc(wanted_cmd, args.wprsc_wayland_debug)

  return query_capabilities()


def run_remote_command(cmd: list[str], env: dict[str, str]) -> None:
  """Run `cmd` on args.destination with extra environment variables.

  The remote command line is ``env K=V ... <extra vars> <cmd>``, where the
  extra variables come from ``--additional-command-env-vars`` split on
  whitespace. Any ``--additional-ssh-command-args`` are passed to ssh itself,
  before the destination. The exit status of ssh is not checked.

  Args:
    cmd: Remote program and its arguments.
    env: Variables to set for the remote program.
  """
  # NOTE(review): values are inserted as-is; nothing here quotes them for the
  # remote shell, so values containing whitespace would presumably be split.
  env_cmd = (
    ['env'] +
    [f'{k}={v}' for k, v in env.items()] +
    args.additional_command_env_vars.split()
  )

  cmd = (
    SSH_COMMON_ARGS +
    # Previously parsed but never used: honour the user's extra ssh options.
    list(args.additional_ssh_command_args) +
    [args.destination] +
    env_cmd +
    cmd
  )

  print(f'Executing remote command: {cmd!r}')
  subprocess.run(cmd, env=os.environ)

def get_x_cursor_size() -> str:
  """Return the local X cursor size as a string.

  Uses XCURSOR_SIZE when it is set and non-empty. Otherwise asks
  ``xrdb -query -get Xcursor.size`` (waiting at most 5 seconds) and uses its
  output. Falls back to '24' when xrdb is unavailable, fails to run, times
  out, or prints nothing.
  """
  default = '24'

  if xcursor_size := os.getenv('XCURSOR_SIZE'):
    return xcursor_size

  cmd = ['xrdb', '-query', '-get', 'Xcursor.size']
  print(f'Executing local command: {cmd!r}')

  timeout_in_secs = 5
  try:
    # text=True so the result is a str, not bytes: it ends up inside a
    # `NAME=value` string for the remote environment. On timeout run() kills
    # the child before raising.
    result = subprocess.run(
      cmd,
      stdout=subprocess.PIPE,
      text=True,
      timeout=timeout_in_secs)
  except (OSError, subprocess.SubprocessError, UnicodeDecodeError):
    # Best effort only: a missing/unrunnable xrdb, a timeout or undecodable
    # output all mean "use the default".
    return default

  # Strip before testing so whitespace-only output also yields the default.
  return result.stdout.strip() or default

def start_remote_command(caps: Capabilities | None) -> None:
  """Run the user's remote command with the environment needed for wprs.

  Sets WAYLAND_DEBUG, WAYLAND_DISPLAY, SSH_AUTH_SOCK (the symlink in the remote
  socket directory) and XCURSOR_SIZE. DISPLAY is added when ``--xwayland`` is
  true and `caps` reports xwayland support; PULSE_SERVER is added when
  ``--pulseaudio-forwarding`` is true.

  Args:
    caps: Capabilities reported by wprsc, or None if none were reported.
  """
  # Look the remote directory up once; every lookup runs remote commands.
  remote_dir = remote_socket_dir()

  env = {
    'WAYLAND_DEBUG': str(int(args.command_wayland_debug)),
    'WAYLAND_DISPLAY': args.wprsd_wayland_display,
    'SSH_AUTH_SOCK': f'{remote_dir}/wprs-ssh-auth.sock',
    'XCURSOR_SIZE': get_x_cursor_size(),
  }

  if args.xwayland:
    if caps and caps.xwayland:
      env['DISPLAY'] = args.wprsd_xwayland_display
    else:
      # Also reached when no capabilities were reported at all.
      print('WARNING: xwayland requested but wprsd has xwayland disabled.',
            file=sys.stderr)

  if args.pulseaudio_forwarding:
    env['PULSE_SERVER'] = f'unix:{remote_dir}/wprs-pulse'

  cmd = [args.remote_command] + args.argument

  # TODO: maybe make ctrl+c kill the remote process.
  run_remote_command(cmd, env)


def forward_sockets() -> None:
  """Forward the wprs sockets and, if enabled, the PulseAudio socket."""
  forward_wprs_sock()
  if not args.pulseaudio_forwarding:
    return
  forward_pulse_sock()


def attach() -> Capabilities | None:
  """Bring up the SSH connection, the socket forwards and wprsc.

  Returns:
    The capabilities reported by wprsc (see maybe_start_wprsc()).

  Raises:
    subprocess.CalledProcessError: if forwarding still fails after one
      reconnect, or if another required ssh command fails.
  """
  tunnel_is_up = check_ssh_tunnel()
  if not tunnel_is_up:
    start_ssh_tunnel()

  # Sometimes forwarding the sockets doesn't work and restarting the ssh
  # connection helps. TODO: try to make this less flaky.
  try:
    forward_sockets()
  except subprocess.CalledProcessError:
    # Full teardown (this also stops wprsc), then exactly one more attempt.
    detach()
    start_ssh_tunnel()
    forward_sockets()

  create_ssh_auth_sock_symlink()

  capabilities = maybe_start_wprsc()
  return capabilities


def detach() -> None:
  """Stop the local wprsc process, then shut down the SSH tunnel."""
  stop_wprsc()
  stop_ssh_tunnel()


def run() -> None:
  """Attach (if necessary) and then run the requested remote command."""
  capabilities = attach()
  start_remote_command(capabilities)


def restart_wprsd() -> None:
  """Restart the remote wprsd user service, then detach locally."""
  restart_cmd = ['systemctl', '--user', 'restart', 'wprsd.service']
  no_extra_env: dict[str, str] = {}
  run_remote_command(restart_cmd, no_extra_env)
  detach()


# Bind each subcommand to the function implementing it; the selected one is
# available as args.func after parsing.
for _subparser, _handler in (
    (parser_attach, attach),
    (parser_detach, detach),
    (parser_run, run),
    (parser_restart_wprsd, restart_wprsd),
):
  _subparser.set_defaults(func=_handler)

# `args` is read as a module-level global by the functions above.
args = parser.parse_args()
print(f'Args: {args}')
if not args.subcommands:
  # parser.error() prints usage plus the message and exits; it never returns.
  parser.error("A subcommand must be provided.")
args.func()
