#!/usr/bin/python3
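"""Synchronise ZFS snapshots between two pools/filesystems using zfs send/receive.

Snapshots that already exist under the same name on both sides are assumed to
hold identical state, so only incremental streams are sent where possible.
Either end may be an SSH remote path such as "user@host:pool/fs".

Example invocation (script, pool and host names are purely illustrative):

    ./zfs-sync.py tank/data backup@nas:backup/myhost --snapname daily -v
"""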
import os
import sys
import locale
import socket
import argparse
from shutil import which
from subprocess import Popen, PIPE, run
from collections import namedtuple

sys.tracebacklimit = 0
encoding = locale.getpreferredencoding(False)

parser = argparse.ArgumentParser(
    description='This utility uses zfs send/receive to transfer snapshots from '
                'one ZFS pool/filesystem to another. If preexisting snapshots with the same '
                'name exist on both sides, they are assumed to hold identical state, and only '
                'incremental sends/receives are done to reduce the amount of data transferred.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    fromfile_prefix_chars='@'
)

parser.add_argument('origin', type=str,
                    help='Origin zfs filesystem, could be an SSH remote path, e.g. "user@host:pool/fs"')

parser.add_argument('destination', type=str,
                    help='Destination prefix, e.g. backup_pool/this_host. The ZFS snapshots from "origin" '
                         'are then stored below that prefix. As with "origin", this can be an SSH remote path')

parser.add_argument('--snapname', metavar='TAGNAME', type=str,
                    help='Only consider snapshots whose names start with TAGNAME')

parser.add_argument('-c', '--cache-dir', metavar="DIR",
                    help='First write the `zfs send` stream to a temporary file in directory DIR, '
                         'then rsync that file to the remote host. This makes it possible to resume '
                         'a partially transferred stream without starting over')

parser.add_argument('-m', '--min-cache-size',
                    help='Minimum (estimated) stream size for the cache-dir functionality to be used, default 50M')

parser.add_argument('-v', '--verbose', action="append_const", const=1,
                    help='Echo the commands that are issued. Given twice, a -v is also passed to the zfs send/recv commands')

parser.add_argument('-n', '--dry-run', action="store_true",
                    help='Only echo the transfer commands that would be issued, '
                         'do not actually send anything.')

args = parser.parse_args()
args.verbose = sum(args.verbose) if args.verbose is not None else 0

if not args.cache_dir and args.min_cache_size is not None:
    print("Cannot specify -m/--min-cache-size without -c/--cache-dir", file=sys.stderr)
    raise SystemExit(1)
elif args.cache_dir and args.min_cache_size is None:
    args.min_cache_size = "50M"

def check_returncode(proc):
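    """Abort the whole run if the finished process exited with a non-zero status."""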
    if proc.returncode != 0:
        print("\nCommand \"{0}\" returned error {1}, aborting...".format(" ".join(proc.args), proc.returncode), file=sys.stderr)
        raise SystemExit(1)


def select_snapshots(proc, prefix):
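    """Collect (filesystem, snapshot name) pairs from `zfs list -t snapshot -H` output.

    The given prefix is stripped from the filesystem names; an empty list is
    returned if the dataset does not exist at all.
    """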
    res = []
    for line in proc.stdout:
        snapshot, *_ = line.decode(encoding).split("\t")
        fs, snapname = snapshot.split("@", 1)
        if not fs.startswith(prefix):
            print(("Unexpexted filesystem: \"{0}\", "
                   "expected a string starting with \"{1}\"").format(fs, prefix), file=sys.stderr)
            raise SystemExit(1)
        fs = fs[len(prefix):]
        if args.snapname:
            if snapname.startswith(args.snapname):
                res.append((fs, snapname))
        else:
            res.append((fs, snapname))
    proc.wait()
    if proc.returncode == 1:
        if proc.stderr.read().endswith(b"dataset does not exist\n"):
            return []
    check_returncode(proc)
    return res


def prefix_ssh(location):
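    """Split "user@host:pool/fs" into an ssh command prefix, the host, and the local path."""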
    if ":" in location:
        host, location = location.split(":")
        ssh = ["ssh", host]
    else:
        location = location
        host = ""
        ssh = []
    return ssh, host, location

ssh_dest, dest_host, destination = prefix_ssh(args.destination)
ssh_orig, orig_host, origin = prefix_ssh(args.origin)

compression = ssh_orig or ssh_dest
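# gzip the stream whenever at least one end is remote; the pipelines are assembled further below.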

if args.cache_dir:
    if ":" in args.cache_dir:
        orig_cache_dir, dest_cache_dir = args.cache_dir.split(":")
    else:
        orig_cache_dir, dest_cache_dir = (args.cache_dir,) * 2
    if bool(ssh_orig) == bool(ssh_dest):
        print("-c/--cache-dir is not supported for two local or two remote filesystems", file=sys.stderr)
        raise SystemExit(1)

    min_cache_size = int(args.min_cache_size.replace("k", "000").replace("M", "000000").replace("G", "000000000"))
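    # The k/M/G suffixes above are expanded as plain decimal multipliers ("50M" -> 50000000).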

if args.verbose > 1:
    verbose = ["-v"]
else:
    verbose = []

with Popen(ssh_dest + ["zfs", "list", "-t", "snapshot", "-r", destination, "-H"], stdout=PIPE, stderr=PIPE) as proc:
    # Split into chunks for each fs
    destinations = {}
    for fs, snapname in select_snapshots(proc, destination):
        if fs not in destinations:
            destinations[fs] = []
        destinations[fs].append(snapname)

with Popen(ssh_orig + ["zfs", "list", "-t", "snapshot", "-r", origin, "-H"], stdout=PIPE, stderr=PIPE) as proc:
    origin_snapshots = select_snapshots(proc, origin)

# Make a dictionary of all origin snapshots
# to quickly test if they exist
origin_dict = {}
for fs, snapname in origin_snapshots:
    origin_dict[fs, snapname] = True


def rsync(orig_url, dest_url):
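    """Copy a (possibly partial) stream file to the other side, resuming where a previous run stopped."""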
    if args.verbose > 1:
        progress = ["--info=progress2"]
    else:
        progress = []
    with echoPopen(["rsync", "--append"] + progress + ["--inplace", orig_url, dest_url]) as rsync:
        rsync.wait()
        check_returncode(rsync)

def command_on_url(*commands, abort=True, dry_run=False, echo=False):
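    """Run a simple command (test/touch/rm) on a local path or, via ssh, on a "host:path" URL."""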
    *commands, url = commands
    if ":" in url:
        host, filename = url.split(":", 1)
        ssh = ["ssh", host]
    else:
        filename = url
        ssh = []

    if echo:
        if ssh:
            print(*ssh, '"' + " ".join(commands + [filename]) + '"')
        else:
            print(*commands, filename)
    if dry_run:
        return True
    cmd = run(ssh + commands + [filename])
    if abort:
        check_returncode(cmd)
    return cmd.returncode == 0

def exists(url):
    return command_on_url("test", "-f", url, abort=False, dry_run=False, echo=False)

def touch(url):
    return command_on_url("touch", url, dry_run=args.dry_run, echo=args.dry_run or args.verbose)

def rm(url):
    return command_on_url("rm", "-v", url, dry_run=args.dry_run, echo=args.dry_run or args.verbose)

MockStdout = namedtuple("MockStdout", ["cmd_string"])
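# MockStdout mimics the .cmd_string attribute that echoPopen attaches to real pipes,
# so dry runs can still print chained "a | b" command lines.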

class echoPopen(Popen):
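    """Popen wrapper that echoes commands when --verbose/--dry-run is given and skips execution on dry runs.

    ssh invocations are printed with the remote part quoted, and pipelines built
    from chained echoPopen objects are reassembled into a single "a | b" line.
    """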
    def __init__(self, commands, stdin=None, stdout=None, **kwargs):
        if stdin is not None and hasattr(stdin, "cmd_string"):
            in_pipe = stdin.cmd_string + " | "
        else:
            in_pipe = ""
        if commands[0] == "ssh":
            cmd = " ".join(commands[0:2]) + ' "' + " ".join(commands[2:]) + '"'
        else:
            cmd = " ".join(commands)
        cmd_string = in_pipe + cmd
        if (args.verbose or args.dry_run) and stdout is None:
            print(cmd_string)

        if not args.dry_run:
            super().__init__(commands, stdin=stdin, stdout=stdout, **kwargs)
            if stdout is not None:
                self.stdout.cmd_string = cmd_string
        else:
            self.stdout = MockStdout(cmd_string)
            self.returncode = 0

    def wait(self):
        if not args.dry_run:
            super().wait()

    def __exit__(self, type, value, traceback):
        if not args.dry_run:
            return super().__exit__(type, value, traceback)

for fs, snapname in origin_snapshots:
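    # Skip snapshots that already exist on the destination; otherwise pick an
    # incremental base (if any) and transfer the snapshot.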
    if fs in destinations and snapname in destinations[fs]:
        continue
    if fs not in destinations:
        # No snapshots of this fs on the destination yet: send a full stream
        last = None
        destinations[fs] = []
    else:
        # Find the newest snapshot present on both sides (last match in list order)
        last = None
        for old_snapshot in destinations[fs]:
            if (fs, old_snapshot) in origin_dict:
                last = old_snapshot

    if last:
        send_cmd = ssh_orig + ["zfs", "send"] + verbose + ["-i", "@{0}".format(last), "{0}{1}@{2}".format(origin, fs, snapname)]
    else:
        send_cmd = ssh_orig + ["zfs", "send"] + verbose + ["{0}{1}@{2}".format(origin, fs, snapname)]

    # Estimate the stream size to decide whether the cache dir should be used
    if args.cache_dir:
        with Popen(send_cmd + ["-nP"], stdout=PIPE) as proc:
            for line in proc.stdout:
                line = line.decode()
                if line.startswith("size"):
                    estimated_size = int(line.split()[1])

        use_cache_dir = estimated_size > min_cache_size
        if use_cache_dir and args.verbose:
            print("# Using rsync cache for this transfer")
        elif args.verbose:
            print("# Not using rsync cache for this transfer")
    else:
        use_cache_dir = False

    if use_cache_dir:
        def check_space(ssh_prefix, where, cache_dir):
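            """Abort if cache_dir on the given side cannot hold the estimated stream plus ~25% slack."""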
            with Popen(ssh_prefix + ["df", "-B1", cache_dir], stdout=PIPE) as df:
                for line in df.stdout:
                    line = line.decode()
                df.wait()
                check_returncode(df)
                free_space = int(line.split()[3])

            if estimated_size * 1.25 > free_space:
                print("Cannot store intermediate stream at {0}, would consume "
                      "about {1} MB free space in {2}, but there is just {3} MB available".format(
                    where, estimated_size // 1000000, cache_dir, free_space // 1000000), file=sys.stderr)
                raise SystemExit(1)

        if not args.dry_run:
            check_space(ssh_orig, "origin", orig_cache_dir)
            check_space(ssh_dest, "destination", dest_cache_dir)
        cache_filename = origin + fs
        if last:
            cache_filename += "@" + last
        cache_filename += "@" + snapname

        # replace "/" and ":"
        cache_filename = cache_filename.replace("/", "_").replace(":", "_")

        dest_cache_filename = os.path.join(dest_cache_dir, cache_filename)
        orig_cache_filename = os.path.join(orig_cache_dir, cache_filename)
        orig_url = ((orig_host + ":") if orig_host else "") + orig_cache_filename
        dest_url = ((dest_host + ":") if dest_host else "") + dest_cache_filename

    if compression:
        send_cmd += ["|", "gzip"]
    if use_cache_dir:
        send_cmd += [">", orig_cache_filename]

    if compression:
        pre_pipe = ["gunzip", "|"]
    else:
        pre_pipe = []
    if use_cache_dir:
        pre_pipe = ["cat", dest_cache_filename, "|"] + pre_pipe

    recv_cmd = ssh_dest + pre_pipe + ["zfs", "recv"] + verbose + ["-u", destination + fs]

    send_shell = not ssh_orig
    if send_shell:
        send_cmd = [" ".join(send_cmd)]

    recv_shell = not ssh_dest
    if recv_shell:
        recv_cmd = [" ".join(recv_cmd)]

    # send
    if use_cache_dir:
        # Via rsync'ed stream files
        if not exists(orig_url + ".total"):
            # Create stream file
            with echoPopen(send_cmd, shell=send_shell) as sender:
                sender.wait()
                check_returncode(sender)
            touch(orig_url + ".total")
        else:
            print("# Resuming upload of partial stream file")

        rsync(orig_url, dest_url)
        recv_stdin = None
    else:
        # direct `zfs send | zfs recv` pipe
        sender = echoPopen(send_cmd, stdout=PIPE, shell=send_shell)
        recv_stdin = sender.stdout

    # recv
    with echoPopen(recv_cmd, stdin=recv_stdin, shell=recv_shell) as receiver:
        receiver.wait()
        check_returncode(receiver)

    # Cleanup
    if not use_cache_dir:
        sender.wait()
        check_returncode(sender)
    else:
        rm(orig_url + ".total")
        rm(orig_url)
        rm(dest_url)

    destinations[fs].append(snapname)
    if args.verbose or args.dry_run:
        print()

# Sanity check: every origin snapshot should now be present below the destination
for fs, snapname in origin_snapshots:
    assert fs in destinations
    assert snapname in destinations[fs]