#!/usr/bin/python3
#
# A tool to test the effect (number of pgs, objects, bytes moved) of a
# crushmap change. This is a wrapper around osdmaptool, relying heavily
# on its --test-map-pgs-dump option to get the list of changed pgs.
# Additionally it uses pg stats to calculate the numbers of objects
# and bytes moved.
#
# Typical usage:
#
# # Get current crushmap
# $ crushdiff export cm.txt
# # Edit the map
# $ $EDITOR cm.txt
# # Check the result
# $ crushdiff compare cm.txt
# # Install the updated map
# $ crushdiff import cm.txt
#
# By default, crushdiff will use the cluster current osdmap and pg
# stats, which requires access to the cluster. But one can use the
# --osdmap and --pg-dump options to test against previously obtained
# data.
#

import argparse
import re
import json
import os
import sys
import tempfile

#
# Global
#

# Command line definition. `command` selects the operation and `crushmap`
# names the crush map file that operation works on.
parser = argparse.ArgumentParser(prog='crushdiff',
                                 description='Tool for updating crush map')
parser.add_argument(
    'command',
    metavar='compare|export|import',
    # Reject unknown commands at parse time; without this a mistyped
    # command is accepted and the tool exits successfully doing nothing.
    choices=('compare', 'export', 'import'),
    help='command',
)
parser.add_argument(
    '-c', '--compiled',
    action='store_true',
    help='use compiled crush map',
    default=False,
)
parser.add_argument(
    'crushmap',
    metavar='crushmap',
    # The file is compiled/decompiled with crushtool unless --compiled
    # is given, so it is a text (not json) crush map by default.
    help='crushmap file (text, or compiled if --compiled is given)',
)
parser.add_argument(
    '-m', '--osdmap',
    metavar='osdmap',
    help='osdmap file to use instead of the cluster current osdmap',
    default=None,
)
parser.add_argument(
    '-p', '--pg-dump',
    metavar='pg-dump',
    help='`ceph pg dump` json output',
    default=None,
)
parser.add_argument(
    '-v', '--verbose',
    action='store_true',
    help='be verbose',
    default=False,
)

#
# Functions
#

def get_human_readable(bytes, precision=2):
    """Format a byte count with a binary (1024-based) unit suffix.

    bytes is the number to format and precision the number of digits
    printed after the decimal point. The value is divided by 1024 until
    it drops below 1024 or the largest known suffix is reached, and the
    result is returned as a string such as '1.50Ki'.
    """
    suffixes = ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei']
    suffix_index = 0
    # `>=` (not `>`) so that an exact multiple such as 1024 is reported
    # as '1.00Ki' rather than '1024.00'.
    while bytes >= 1024 and suffix_index < len(suffixes) - 1:
        suffix_index += 1
        bytes = bytes / 1024.0
    return '%.*f%s' % (precision, bytes, suffixes[suffix_index])

def run_cmd(cmd, verbose=False):
    """Run a shell command line and fail loudly if it does not succeed.

    cmd is handed to the shell through os.system, so shell redirections
    in it work. When verbose is true the command line is echoed to
    stderr before it is run.

    Raises RuntimeError if os.system reports a non-zero status, so that
    a failed external tool is not mistaken for success and discovered
    only later as a confusing error on its missing or empty output.
    """
    if verbose:
        print(cmd, file=sys.stderr, flush=True)
    status = os.system(cmd)
    if status != 0:
        raise RuntimeError(
            'command failed (status {}): {}'.format(status, cmd))

def get_osdmap(file):
    """Read the file at path `file` and return its parsed JSON content."""
    with open(file, "r") as handle:
        parsed = json.load(handle)
    return parsed

def get_pools(osdmap):
    """Return the entries of osdmap['pools'] keyed by their 'pool' field."""
    pools_by_id = {}
    for pool in osdmap['pools']:
        pools_by_id[pool['pool']] = pool
    return pools_by_id

def get_erasure_code_profiles(osdmap):
    """Return the 'erasure_code_profiles' entry of an osdmap dump."""
    profiles = osdmap['erasure_code_profiles']
    return profiles

def get_pgmap(pg_dump_file):
    """Load a pg dump JSON file and return its pg map.

    If the top-level document has a 'pg_map' key its value is returned,
    otherwise the whole document is returned.
    """
    with open(pg_dump_file, "r") as handle:
        document = json.load(handle)
    return document.get('pg_map', document)

def get_pg_stats(pgmap):
    """Return the entries of pgmap['pg_stats'] keyed by their 'pgid'."""
    stats_by_pgid = {}
    for entry in pgmap['pg_stats']:
        stats_by_pgid[entry['pgid']] = entry
    return stats_by_pgid

def parse_test_map_pgs_dump(file):
    """Parse a pg mapping dump into a {pgid: [osd, ...]} dict.

    file is the path of a text file in the format shown below. Lines
    before the first 'pool ...' header are ignored, parsing stops at the
    '#osd' summary header, and lines that do not look like a pg mapping
    are skipped. A pg mapped to no osds ('[]') is returned with an empty
    list.
    """
    # Format:
    # pool 1 pg_num 16
    # 1.0 [1,0,2] 1
    # 1.1 [2,0,1] 2
    # ...
    # pool 2 pg_num 32
    # 2.0 [2,1,0] 2
    # 2.1 [2,1,0] 2
    # ...
    # #osd count first primary c wt wt
    # osd.1 208 123 123 0.098587 1

    # Raw strings: '\d' and '\s' are not valid escapes in a plain string.
    pool_re = re.compile(r'^pool (\d+) pg_num (\d+)')
    # `*` rather than `+` inside the brackets so that an empty osd set is
    # recorded instead of the pg being dropped from the result.
    pg_re = re.compile(r'^(\d+\.[0-9a-f]+)\s+\[([\d,]*)\]')

    pgs = {}
    seen_pool = False

    with open(file, "r") as f:
        for line in f:
            if pool_re.match(line):
                seen_pool = True
                continue
            if not seen_pool:
                continue
            if line.startswith('#osd'):
                break
            m = pg_re.match(line)
            if not m:
                continue
            # `if x` skips the empty string produced by splitting ''.
            pgs[m.group(1)] = [int(x) for x in m.group(2).split(',') if x]

    return pgs

def do_compare(new_crushmap_in, osdmap=None, pg_dump=None, compiled=False,
               verbose=False):
    """Report how many pgs, objects and bytes a new crushmap would move.

    new_crushmap_in is the crushmap to evaluate; it is compiled with
    `crushtool -c` first unless compiled is true. osdmap is an osdmap
    file to test against; when not given the map is fetched with
    `ceph osd getmap`. pg_dump is a file holding pg dump json; when not
    given it is produced with `ceph pg dump --format json`.

    The pg mappings of the osdmap before and after importing the new
    crushmap are obtained with `osdmaptool --test-map-pgs-dump` and
    compared. Summary lines are printed to stdout; warnings about
    undersized pgs and, when verbose is true, the executed commands and
    the changed pgs are printed to stderr.
    """
    # Imported locally: shlex is not among this file's top-level imports.
    from shlex import quote

    with tempfile.TemporaryDirectory() as tmpdirname:
        # Paths are quoted before being placed on a shell command line so
        # that names containing spaces or shell metacharacters work.
        if compiled:
            new_crushmap_file = new_crushmap_in
        else:
            new_crushmap_file = os.path.join(tmpdirname, 'crushmap')
            run_cmd('crushtool -c {} -o {}'.format(
                quote(new_crushmap_in), quote(new_crushmap_file)), verbose)

        # Work on a private copy of the osdmap.
        # NOTE(review): presumably because `osdmaptool --import-crush`
        # rewrites the file it is given -- confirm.
        osdmap_file = os.path.join(tmpdirname, 'osdmap')
        if osdmap:
            run_cmd('cp {} {}'.format(quote(osdmap), quote(osdmap_file)),
                    verbose)
        else:
            run_cmd('ceph osd getmap -o {}'.format(quote(osdmap_file)),
                    verbose)

        if not pg_dump:
            pg_dump = os.path.join(tmpdirname, 'pg_dump.json')
            run_cmd('ceph pg dump --format json > {}'.format(quote(pg_dump)),
                    verbose)

        # The "old" mapping must be dumped before the new crushmap is
        # imported into the same osdmap file below.
        old_test_map_pgs_dump = os.path.join(tmpdirname, 'pgs.old.txt')
        run_cmd('osdmaptool {} --test-map-pgs-dump > {}'.format(
            quote(osdmap_file), quote(old_test_map_pgs_dump)), verbose)
        if verbose:
            run_cmd('cat {} >&2'.format(quote(old_test_map_pgs_dump)), True)

        new_test_map_pgs_dump = os.path.join(tmpdirname, 'pgs.new.txt')
        run_cmd(
            'osdmaptool {} --import-crush {} --test-map-pgs-dump > {}'.format(
                quote(osdmap_file), quote(new_crushmap_file),
                quote(new_test_map_pgs_dump)), verbose)
        if verbose:
            run_cmd('cat {} >&2'.format(quote(new_test_map_pgs_dump)), True)

        osdmap_file_json = os.path.join(tmpdirname, 'osdmap.json')
        run_cmd('osdmaptool {} --dump json > {}'.format(
            quote(osdmap_file), quote(osdmap_file_json)), verbose)
        osdmap_json = get_osdmap(osdmap_file_json)
        pools = get_pools(osdmap_json)
        ec_profiles = get_erasure_code_profiles(osdmap_json)

        pg_stats = get_pg_stats(get_pgmap(pg_dump))

        old_pgs = parse_test_map_pgs_dump(old_test_map_pgs_dump)
        new_pgs = parse_test_map_pgs_dump(new_test_map_pgs_dump)

    # Everything needed is in memory now; the temporary files are gone.

    # Pass 1: whole pgs (and their objects) whose mapping changed.
    diff_pg_count = 0
    total_object_count = 0
    diff_object_count = 0
    for pgid, old_osds in old_pgs.items():
        objects = pg_stats[pgid]['stat_sum']['num_objects']
        total_object_count += objects

        new_osds = new_pgs[pgid]
        if old_osds == new_osds:
            continue

        pool = pools[int(pgid.split('.')[0])]

        if len(new_osds) < pool['size']:
            print("WARNING: {} will be undersized ({})".format(
                pgid, new_osds), file=sys.stderr, flush=True)

        # For a pool without an erasure code profile, the same osds in a
        # different order are not counted as a change.
        if not pool['erasure_code_profile'] and \
           sorted(old_osds) == sorted(new_osds):
            continue

        if verbose:
            print("{}\t{} -> {}".format(pgid, old_osds, new_osds),
                  file=sys.stderr, flush=True)
        diff_pg_count += 1
        diff_object_count += objects

    print("{}/{} ({:.2f}%) pgs affected".format(
        diff_pg_count, len(old_pgs),
        100 * diff_pg_count / len(old_pgs) if len(old_pgs) else 0),
        flush=True)
    print("{}/{} ({:.2f}%) objects affected".format(
        diff_object_count, total_object_count,
        100 * diff_object_count / total_object_count
        if total_object_count else 0), flush=True)

    # Pass 2: individual pg shards, object shards and bytes to move.
    total_pg_shard_count = 0
    diff_pg_shard_count = 0
    total_object_shard_count = 0
    diff_object_shard_count = 0
    total_bytes = 0
    diff_bytes = 0
    for pgid, old_osds in old_pgs.items():
        pool = pools[int(pgid.split('.')[0])]
        ec_profile = pool['erasure_code_profile']
        if ec_profile:
            k = int(ec_profiles[ec_profile]['k'])
            m = int(ec_profiles[ec_profile]['m'])
        else:
            # Replicated pool: one data copy plus size - 1 extra copies.
            k = 1
            m = pool['size'] - 1

        stat_sum = pg_stats[pgid]['stat_sum']
        pg_bytes = stat_sum['num_bytes'] + stat_sum['num_omap_bytes']
        objects = stat_sum['num_objects']

        total_pg_shard_count += len(old_osds)
        total_object_shard_count += objects * (k + m)
        total_bytes += pg_bytes * (k + m) / k

        new_osds = new_pgs[pgid]
        if old_osds == new_osds:
            continue

        moved_shards = 0
        if ec_profile:
            # Shards are compared position by position. The new mapping
            # is padded so that a position missing from it counts as
            # moved instead of raising IndexError.
            padded_new = new_osds + [None] * (len(old_osds) - len(new_osds))
            for old_osd, new_osd in zip(old_osds, padded_new):
                if old_osd != new_osd:
                    moved_shards += 1
        else:
            # Order is irrelevant here: a shard moves only if its osd is
            # no longer part of the new set.
            for osd in old_osds:
                if osd not in new_osds:
                    moved_shards += 1

        if not moved_shards:
            continue

        diff_pg_shard_count += moved_shards
        diff_object_shard_count += objects * moved_shards
        diff_bytes += moved_shards * (pg_bytes / k)

        if verbose:
            print("{}\t{} -> {}".format(pgid, old_osds, new_osds),
                  file=sys.stderr, flush=True)

    print("{}/{} ({:.2f}%) pg shards to move".format(
        diff_pg_shard_count, total_pg_shard_count,
        100 * diff_pg_shard_count / total_pg_shard_count
        if total_pg_shard_count else 0), flush=True)
    print("{}/{} ({:.2f}%) pg object shards to move".format(
        diff_object_shard_count, total_object_shard_count,
        100 * diff_object_shard_count / total_object_shard_count
        if total_object_shard_count else 0), flush=True)
    print("{}/{} ({:.2f}%) bytes to move".format(
        get_human_readable(int(diff_bytes)),
        get_human_readable(int(total_bytes)),
        100 * diff_bytes / total_bytes if total_bytes else 0),
        flush=True)

def do_export(crushmap_out, osdmap_file=None, compiled=False, verbose=False):
    """Extract the crushmap of an osdmap and write it to crushmap_out.

    osdmap_file is the osdmap to read; when not given the map is fetched
    into a temporary file with `ceph osd getmap`. The crushmap is
    exported with `osdmaptool --export-crush`. If compiled is true the
    exported map is written to crushmap_out as is, otherwise it is
    decompiled with `crushtool -d` into crushmap_out. verbose is passed
    on to run_cmd.
    """
    # Imported locally: shlex is not among this file's top-level imports.
    from shlex import quote

    with tempfile.TemporaryDirectory() as tmpdirname:
        if not osdmap_file:
            osdmap_file = os.path.join(tmpdirname, 'osdmap')
            run_cmd('ceph osd getmap -o {}'.format(quote(osdmap_file)),
                    verbose)

        if compiled:
            crushmap_file = crushmap_out
        else:
            crushmap_file = os.path.join(tmpdirname, 'crushmap')

        # Paths are quoted so that names containing spaces or shell
        # metacharacters survive the shell command line.
        run_cmd('osdmaptool {} --export-crush {}'.format(
            quote(osdmap_file), quote(crushmap_file)), verbose)
        if not compiled:
            run_cmd('crushtool -d {} -o {}'.format(
                quote(crushmap_file), quote(crushmap_out)), verbose)

def do_import(crushmap_in, osdmap=None, compiled=False, verbose=False):
    """Install the crushmap from crushmap_in.

    Unless compiled is true, crushmap_in is first compiled with
    `crushtool -c` into a temporary file. If osdmap is given the
    crushmap is imported into that osdmap file with
    `osdmaptool --import-crush`; otherwise it is set on the cluster with
    `ceph osd setcrushmap`. verbose is passed on to run_cmd.
    """
    # Imported locally: shlex is not among this file's top-level imports.
    from shlex import quote

    with tempfile.TemporaryDirectory() as tmpdirname:
        # Paths are quoted so that names containing spaces or shell
        # metacharacters survive the shell command line.
        if compiled:
            crushmap_file = crushmap_in
        else:
            crushmap_file = os.path.join(tmpdirname, 'crushmap')
            run_cmd('crushtool -c {} -o {}'.format(
                quote(crushmap_in), quote(crushmap_file)), verbose)

        if osdmap:
            run_cmd('osdmaptool {} --import-crush {}'.format(
                quote(osdmap), quote(crushmap_file)), verbose)
        else:
            run_cmd('ceph osd setcrushmap -i {}'.format(
                quote(crushmap_file)), verbose)

def main():
    """Parse the command line and run the selected command.

    An unrecognised command is reported through parser.error, which
    prints a usage message and exits, instead of being silently ignored.
    """
    args = parser.parse_args()

    if args.command == 'compare':
        do_compare(args.crushmap, args.osdmap, args.pg_dump, args.compiled,
                   args.verbose)
    elif args.command == 'export':
        do_export(args.crushmap, args.osdmap, args.compiled, args.verbose)
    elif args.command == 'import':
        do_import(args.crushmap, args.osdmap, args.compiled, args.verbose)
    else:
        parser.error('unknown command: {}'.format(args.command))

#
# main
#

# Run the tool only when executed as a script, so that importing this
# file does not parse sys.argv and run commands as a side effect.
if __name__ == '__main__':
    main()
