import os
import sys
import locale
import argparse
from subprocess import Popen, PIPE, run
from collections import namedtuple

# Suppress Python traceback output on aborts; this script reports errors
# itself and exits via SystemExit with the relevant return code.
sys.tracebacklimit = 0
# Encoding used to decode bytes read from subprocess pipes (zfs/ssh output).
# NOTE(review): locale.getdefaultlocale() is deprecated since Python 3.11;
# locale.getpreferredencoding(False) would be the modern replacement.
encoding = locale.getdefaultlocale()[1]

# Command-line interface. The ArgumentParser call was missing its closing
# parenthesis (syntax error); help-text typos fixed.
parser = argparse.ArgumentParser(
    description='This utility uses zfs send/receive to transfer snapshots from '
                'one zfs pool/filesystem to another. If preexisting snapshots with the same '
                'name exist on both sides they are assumed to hold identical state and only '
                'incremental send and receives are done to reduce the amount of data transferred.')

parser.add_argument('origin', type=str,
                    help='Origin zfs filesystem, could be an SSH remote path, e.g. "user@host:pool/fs"')

parser.add_argument('destination', type=str,
                    help='Destination prefix, e.g. backup_pool/this_host. The ZFS snapshots from "origin" '
                         'are then stored below there. As with "origin", can be an SSH remote path')

parser.add_argument('--snapname', metavar='TAGNAME', type=str, action='append',
                    help='Only consider snapshot names starting with TAGNAME. '
                         'Can be specified more than once')

parser.add_argument('--ignore-snapname', metavar='TAGNAME', type=str, action='append',
                    help='Do not consider snapshot names starting with TAGNAME. '
                         'Can be specified more than once')

parser.add_argument('-c', '--cache-dir', metavar="DIR",
                    help='First create a temporary file in directory DIR with the `zfs send` stream, '
                         'then rsync that to the remote host. With that, one can resume '
                         'a partially sent file without starting over')

parser.add_argument('-m', '--min-cache-size',
                    help='Minimum (estimated) size of stream to use the cache-dir functionality, default 50M')

parser.add_argument('-z', '--compression', action="store_true",
                    help='Filter network streams through gzip/gunzip')

parser.add_argument('-v', '--verbose', action="append_const", const=1,
                    help='Echo the commands that are issued. Two -v pass a -v along to the zfs send/receive commands')

parser.add_argument('-n', '--dry-run', action="store_true",
                    help='Only echo the transfer commands that would be issued,\n'
                         'do not actually send anything.')

parser.add_argument('-F', '--force', action="store_true",
                    help='Pass `-F` to `zfs receive` to overwrite other '
                         'snapshots or diverged changes on the remote side\n')

args = parser.parse_args()
# -v may be repeated; "append_const" collects 1s, collapse them into a count (0 if absent)
args.verbose = sum(args.verbose) if args.verbose is not None else 0

# -m only makes sense together with -c; default the threshold when -c is given alone
if not args.cache_dir and args.min_cache_size is not None:
    print("Cannot specify -m/--min-cache-size without -c/--cache-dir", file=sys.stderr)
    raise SystemExit(1)
elif args.cache_dir and args.min_cache_size is None:
    args.min_cache_size = "50M"

def check_returncode(proc, verbose=True):
    """Exit the program with *proc*'s return code when it is non-zero.

    With ``verbose`` an explanatory message is printed to stderr first.
    """
    if proc.returncode == 0:
        return
    if verbose:
        print("\nCommand \"{0}\" returned error {1}, aborting...".format(" ".join(proc.args), proc.returncode), file=sys.stderr)
    raise SystemExit(proc.returncode)

def select_snapshots(proc, prefix):
    """Parse `zfs list -t snapshot -H` output from *proc*.

    Returns a list of (fs, snapname) tuples, with *prefix* stripped from the
    filesystem name and the --snapname / --ignore-snapname filters applied.
    A missing dataset ("dataset does not exist") yields an empty list; any
    other failure aborts the program.
    """
    res = []
    for line in proc.stdout:
        # -H output is tab separated; the snapshot name is the first column
        snapshot, *_ = line.decode(encoding).split("\t")
        fs, snapname = snapshot.split("@", 2)
        if not fs.startswith(prefix):
            print(("Unexpected filesystem: \"{0}\", "
                   "expected a string starting with \"{1}\"").format(fs, prefix), file=sys.stderr)
            raise SystemExit(1)
        fs = fs[len(prefix):]
        ignore = False
        if args.ignore_snapname:
            for ignore_snapname in args.ignore_snapname:
                if snapname.startswith(ignore_snapname):
                    ignore = True
        if ignore:
            continue
        if args.snapname:
            # Keep only snapshots matching one of the allowed prefixes
            for allowed_snapname in args.snapname:
                if snapname.startswith(allowed_snapname):
                    res.append((fs, snapname))
        else:
            res.append((fs, snapname))
    # Drain stderr and wait so that returncode is actually populated;
    # before wait() Popen.returncode is still None.
    err = proc.stderr.read()
    proc.wait()
    if proc.returncode == 1:
        if err.endswith(b"dataset does not exist\n"):
            return []
    check_returncode(proc, verbose=os.isatty(2))
    return res

def prefix_ssh(location):
    """Split a possibly remote "host:path" location.

    Returns a tuple ``(ssh, host, location)``: ``ssh`` is the command prefix
    (``["ssh", host]`` for remote locations, ``[]`` for local ones), ``host``
    is the remote host (``""`` for local) and ``location`` is the bare path.
    """
    # The corrupted original fused both branches; restore the local/remote split.
    if ":" in location:
        host, location = location.split(":")
        ssh = ["ssh", host]
    else:
        host = ""
        ssh = []
    return ssh, host, location

ssh_dest, dest_host, destination = prefix_ssh(args.destination)
ssh_orig, orig_host, origin = prefix_ssh(args.origin)

if args.cache_dir:
    # Cache dir may be "origin_dir:dest_dir" or a single dir used on both sides
    if ":" in args.cache_dir:
        orig_cache_dir, dest_cache_dir = args.cache_dir.split(":")
    else:
        orig_cache_dir, dest_cache_dir = (args.cache_dir,) * 2
    # Caching only works for exactly one remote end (the stream is rsync'ed
    # from the local to the remote cache dir, or vice versa)
    if not (bool(ssh_orig) != bool(ssh_dest)):
        print("-c/--cache-dir is not supported for two local or two remote filesystems", file=sys.stderr)
        raise SystemExit(1)

    # Expand k/M/G suffixes into a plain byte count
    min_cache_size = int(args.min_cache_size.replace("k", "000").replace("M", "000000").replace("G", "000000000"))

# Extra flag passed along to zfs send/receive with -vv
if args.verbose > 1:
    verbose = ["-v"]
else:
    verbose = []

if args.force:
    force = ["-F"]
else:
    force = []

with Popen(ssh_dest + ["zfs", "list", "-t", "snapshot", "-r", destination, "-H"], stdout=PIPE, stderr=PIPE) as proc:
    # Split into chunks for each fs. The main loop uses this both to skip
    # already-present snapshots and to find a common base for incremental sends.
    destinations = {}
    for fs, snapname in select_snapshots(proc, destination):
        if fs not in destinations:
            destinations[fs] = []
        # The corrupted original never filled the per-fs lists
        destinations[fs].append(snapname)

with Popen(ssh_orig + ["zfs", "list", "-t", "snapshot", "-r", origin, "-H"], stdout=PIPE, stderr=PIPE) as proc:
    origin_snapshots = select_snapshots(proc, origin)

# Make a dictionary of all origin snapshots
# to quickly test if they exist
origin_dict = {}
for fs, snapname in origin_snapshots:
    origin_dict[fs, snapname] = True

def rsync(orig_url, dest_url):
    """Transfer the cached stream file with rsync.

    --append/--inplace let a partially uploaded file be resumed instead of
    restarting from scratch.  Aborts the program if rsync fails.
    """
    if args.verbose > 1:
        progress = ["--info=progress2"]
    else:
        progress = []
    # Local variable renamed from `rsync` to avoid shadowing this function
    with echoPopen(["rsync", "--append"] + progress + ["--inplace", orig_url, dest_url]) as proc:
        proc.wait()
        check_returncode(proc)

def command_on_url(*commands, abort=True, dry_run=False, echo=False):
    """Run a command whose last positional argument is a (possibly remote) file.

    The last element of *commands* is a URL: "host:path" runs the command via
    ssh on host, a plain path runs it locally.  With ``abort`` the program
    exits when the command fails.  Returns True on success, False otherwise
    (always True under ``dry_run``).  ``echo`` prints the command first.
    """
    *commands, url = commands
    if ":" in url:
        host, filename = url.split(":")
        ssh = ["ssh", host]
    else:
        filename = url
        ssh = []

    if echo:
        if ssh:
            # Quote the remote part like an interactive ssh invocation
            print(*ssh, '"' + " ".join(commands + [filename]) + '"')
        else:
            print(*commands, filename)

    if not dry_run:
        cmd = run(ssh + commands + [filename])
        if abort:
            check_returncode(cmd)
        # abort=False callers (e.g. exists()) branch on this boolean
        return cmd.returncode == 0

    return True

def exists(url):
    """Check whether a regular file is present at *url* (local or remote)."""
    # Never dry-run and never echo: callers branch on the answer.
    found = command_on_url("test", "-f", url, abort=False, dry_run=False, echo=False)
    return found

def touch(url):
    """Create (or update the mtime of) the file at *url*, honoring --dry-run/--verbose."""
    echo = args.dry_run or args.verbose
    return command_on_url("touch", url, dry_run=args.dry_run, echo=echo)

def rm(url):
    """Remove the file at *url*, honoring --dry-run/--verbose."""
    echo = args.dry_run or args.verbose
    return command_on_url("rm", "-v", url, dry_run=args.dry_run, echo=echo)

# Stand-in for a process stdout under --dry-run: only carries the command
# string used to render `a | b` pipelines when echoing.
MockStdout = namedtuple("MockStdout", ["cmd_string"])

class echoPopen(Popen):
    """Popen that echoes the (pipeline) command when verbose and is a no-op under --dry-run."""

    def __init__(self, commands, stdin=None, stdout=None, **kwargs):
        # Build the shell-style pipeline string for echoing; an upstream
        # echoPopen's stdout carries its own cmd_string (see below / MockStdout).
        if stdin is not None:
            in_pipe = stdin.cmd_string + " | "
        else:
            in_pipe = ""
        if commands[0] == "ssh":
            # Quote the remote part: ssh host "cmd ..."
            cmd = " ".join(commands[0:2]) + ' "' + " ".join(commands[2:]) + '"'
        else:
            cmd = " ".join(commands)
        cmd_string = in_pipe + cmd

        if args.verbose or args.dry_run:
            # Only the final stage of a pipeline (stdout not captured) prints;
            # intermediate stages are echoed via their consumer's in_pipe.
            if stdout is None:
                print(cmd_string)

        if not args.dry_run:
            super().__init__(commands, stdin=stdin, stdout=stdout, **kwargs)
            if stdout is not None:
                # Let the next pipeline stage pick up our command string
                self.stdout.cmd_string = cmd_string
        else:
            # Dry run: no process; pretend success and expose the command string
            self.stdout = MockStdout(cmd_string)
            self.returncode = 0

    def wait(self):
        if not args.dry_run:
            return super().wait()

    def __exit__(self, type, value, traceback):
        if not args.dry_run:
            return super().__exit__(type, value, traceback)

for fs, snapname in origin_snapshots:
    # Already on the destination? Nothing to do for this snapshot.
    if fs in destinations and snapname in destinations[fs]:
        continue
    if fs not in destinations:
        # Send full
        last = None
        destinations[fs] = []
    else:
        # Find older common snapshot to use as the incremental base
        last = None
        for old_snapshot in destinations[fs]:
            if (fs, old_snapshot) in origin_dict:
                last = old_snapshot

    if last:
        # Incremental send from the last common snapshot
        send_cmd = ssh_orig + ["zfs", "send"] + verbose + ["-i", "@{0}".format(last), "{0}{1}@{2}".format(origin, fs, snapname)]
    else:
        # Full send
        send_cmd = ssh_orig + ["zfs", "send"] + verbose + ["{0}{1}@{2}".format(origin, fs, snapname)]

    # Check free space
    if args.cache_dir:
        # `zfs send -nP` prints a size estimate without sending anything
        with Popen(send_cmd + ["-nP"], stdout=PIPE) as proc:
            for line in proc.stdout:
                line = line.decode()
                if line.startswith("size"):
                    estimated_size = int(line.split()[1])

        use_cache_dir = estimated_size > min_cache_size
        if use_cache_dir and args.verbose:
            print("# Using rsync cache for this transfer")
        elif args.verbose:
            print("# Not using rsync cache for this transfer")
    else:
        use_cache_dir = False

    if use_cache_dir:
        def check_space(ssh_prefix, where, cache_dir):
            # Keep the last line of `df -B1` output (the data line after the header)
            with Popen(ssh_prefix + ["df", "-B1", cache_dir], stdout=PIPE) as df:
                for line in df.stdout:
                    line = line.decode()
                free_space = int(line.split()[3])

            # Require 25% headroom over the estimate
            if estimated_size * 1.25 > free_space:
                print("Cannot store intermediate stream at {0}, would consume "
                      "about {1} MB free space in {2}, but there is just {3} MB available".format(
                          where, estimated_size // 1000000, cache_dir, free_space // 1000000), file=sys.stderr)
                raise SystemExit(1)

        if not args.dry_run:
            check_space(ssh_orig, "origin", orig_cache_dir)
            check_space(ssh_dest, "destination", dest_cache_dir)
        cache_filename = origin + fs
        if last:
            cache_filename += "@" + last
        cache_filename += "@" + snapname

        # replace "/" and ":"
        cache_filename = cache_filename.replace("/", "_").replace(":", "_")

        dest_cache_filename = os.path.join(dest_cache_dir, cache_filename)
        orig_cache_filename = os.path.join(orig_cache_dir, cache_filename)
        orig_url = ((orig_host + ":") if orig_host else "") + orig_cache_filename
        dest_url = ((dest_host + ":") if dest_host else "") + dest_cache_filename

    if args.compression:
        send_cmd += ["|", "gzip"]
    if use_cache_dir:
        send_cmd += [">", orig_cache_filename]

    if args.compression:
        pre_pipe = ["gunzip", "|"]
    else:
        pre_pipe = []
    if use_cache_dir:
        pre_pipe = ["cat", dest_cache_filename, "|"] + pre_pipe

    receive_cmd = ssh_dest + pre_pipe + ["zfs", "receive"] + force + verbose + ["-u", destination + fs]

    # Local commands contain shell syntax ("|", ">"), so join and run via a shell;
    # remote ones are interpreted by the shell ssh spawns on the other end.
    send_shell = not ssh_orig
    if send_shell:
        send_cmd = [" ".join(send_cmd)]

    receive_shell = not ssh_dest
    if receive_shell:
        receive_cmd = [" ".join(receive_cmd)]

    # send
    if use_cache_dir:
        # Via rsync'ed stream files
        if not exists(orig_url + ".total"):
            # Create stream file
            with echoPopen(send_cmd, shell=send_shell) as sender:
                sender.wait()
                check_returncode(sender)
            # Marker: the stream file on the origin side is complete
            touch(orig_url + ".total")
        else:
            print("# Resuming upload of partial stream file")

        rsync(orig_url, dest_url)
        receive_stdin = None
    else:
        # direct `zfs send | zfs receive` pipe
        sender = echoPopen(send_cmd, stdout=PIPE, shell=send_shell)
        receive_stdin = sender.stdout

    # receive
    with echoPopen(receive_cmd, stdin=receive_stdin, shell=receive_shell) as receiver:
        receiver.wait()
        check_returncode(receiver)

    # Record the transfer so later snapshots of this fs go incrementally
    # (and so the final sanity check passes)
    destinations[fs].append(snapname)

    # Cleanup. The corrupted original had `if not use_cache_dir:`, which would
    # reference orig_url when it was never defined; the marker only exists
    # when the cache was used.
    if use_cache_dir:
        # NOTE(review): only the completion marker is removed here; the stream
        # files themselves stay in the cache dirs — confirm that is intended.
        rm(orig_url + ".total")

    if args.verbose or args.dry_run:
        # Blank separator line between transfers
        print()

# test
# Post-condition: after the transfer loop, every origin snapshot must be
# recorded as present on the destination side.
for fs, snapname in origin_snapshots:
    assert fs in destinations
    assert snapname in destinations[fs]