# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
import argparse
from collections import deque
from functools import reduce

import paddlers
import numpy as np
import cv2
try:
    from osgeo import gdal
except:
    import gdal
from tqdm import tqdm

from utils import time_it

# Label value treated as "ignore": pixels with this value are excluded from
# the per-block class histograms computed in `QuadTree.preprocess`.
IGN_CLS = 255
# File-name template for saved patches. `idx` is filled with the index of the
# quadtree node and `ext` with the extension of the source file.
FMT = "im_{idx}{ext}"


class QuadTreeNode(object):
    """A rectangular region of the mask, stored as one node of a quadtree.

    Attributes are set directly from the constructor arguments:
    `i`/`j` are the row/column offsets of the region, `h`/`w` its height and
    width, `level` its depth in the tree, and `cls_info` an indexable
    sequence of per-class counts (or None for a background node).
    """

    def __init__(self, i, j, h, w, level, cls_info=None):
        super().__init__()
        self.i, self.j = i, j
        self.h, self.w = h, w
        self.level = level
        self.cls_info = cls_info
        # Every node starts with four empty child slots.
        self.reset_children()

    @property
    def area(self):
        """Number of pixels covered by the node (height times width)."""
        return self.h * self.w

    @property
    def is_bg_node(self):
        """True when the node carries no class information."""
        return self.cls_info is None

    @property
    def coords(self):
        """The region as an `(i, j, h, w)` tuple."""
        return (self.i, self.j, self.h, self.w)

    def get_cls_cnt(self, cls):
        """Return the stored count for class `cls`, or 0 if it is unavailable."""
        counts = self.cls_info
        if counts is not None and cls < len(counts):
            return counts[cls]
        return 0

    def get_children(self):
        """Yield the child nodes that are actually present (skip empty slots)."""
        yield from (kid for kid in self.children if kid is not None)

    def reset_children(self):
        """Drop all children, leaving four empty slots."""
        self.children = [None] * 4

    def __repr__(self):
        return f"{self.__class__.__name__}({self.i}, {self.j}, {self.h}, {self.w})"


class QuadTree(object):
    """Quadtree over a mask raster, built from a grid of fixed-size blocks.

    The mask is divided into blocks of `min_blk_size` x `min_blk_size` pixels
    (blocks on the bottom/right border may be smaller). Each block becomes a
    leaf that stores a per-class pixel histogram, or None when the block only
    contains background pixels. Inner nodes store the merged histogram of
    their children; subtrees that are entirely background are collapsed into
    a single childless node.

    Node extents are computed as multiples of `min_blk_size`, so nodes that
    touch the bottom/right border may extend past the raster.
    """

    def __init__(self, min_blk_size=256):
        """
        Args:
            min_blk_size: Side length, in pixels, of the smallest block.
        """
        super().__init__()
        self.min_blk_size = min_blk_size
        # NOTE: `h` and `w` are never assigned by any method of this class.
        self.h = None
        self.w = None
        self.root = None

    def build_tree(self, mask_band, bg_cls=0):
        """Build the tree from a mask raster and return its root node.

        Args:
            mask_band: Object exposing `RasterYSize`, `RasterXSize` and
                `ReadAsArray(xoff, yoff, xsize, ysize)` (forwarded to
                `preprocess`).
            bg_cls: Index of the background class.

        Returns:
            The root node, or None if the block table is empty.

        Raises:
            ValueError: Propagated from `preprocess`.
        """
        cls_info_table = self.preprocess(mask_band, bg_cls)
        n_rows = len(cls_info_table)
        if n_rows == 0:
            return None
        n_cols = len(cls_info_table[0])
        self.root = self._build_tree(cls_info_table, 0, n_rows - 1, 0,
                                     n_cols - 1, 0)
        return self.root

    def preprocess(self, mask_ds, bg_cls):
        """Compute the per-block class histograms of the mask.

        Args:
            mask_ds: Object exposing `RasterYSize`, `RasterXSize` and
                `ReadAsArray(xoff, yoff, xsize, ysize)`.
            bg_cls: Index of the background class.

        Returns:
            A list of rows; each entry is either the block's histogram
            (as returned by `np.bincount`, with the `IGN_CLS` count zeroed)
            or None when all non-ignored pixels of the block are background.

        Raises:
            ValueError: If `min_blk_size` is not smaller than both image
                dimensions.
        """
        h, w = mask_ds.RasterYSize, mask_ds.RasterXSize
        s = self.min_blk_size
        if s >= h or s >= w:
            raise ValueError("`min_blk_size` must be smaller than image size.")
        cls_info_table = []
        for i in range(0, h, s):
            # Blocks on the bottom/right border are clipped to the raster.
            ch = min(s, h - i)
            cls_info_row = []
            for j in range(0, w, s):
                cw = min(s, w - j)
                arr = mask_ds.ReadAsArray(j, i, cw, ch)
                bins = np.bincount(arr.ravel())
                if len(bins) > IGN_CLS:
                    # Zero the ignore bin instead of deleting it: deleting
                    # would shift every class index above IGN_CLS down by one.
                    bins[IGN_CLS] = 0
                if len(bins) > bg_cls and bins.sum() == bins[bg_cls]:
                    cls_info_row.append(None)
                else:
                    cls_info_row.append(bins)
            cls_info_table.append(cls_info_row)
        return cls_info_table

    def _build_tree(self, cls_info_table, i_st, i_ed, j_st, j_ed, level=0):
        """Recursively build the subtree covering a range of table cells.

        Args:
            cls_info_table: Table produced by `preprocess`.
            i_st, i_ed: First and last row index (inclusive) of the range.
            j_st, j_ed: First and last column index (inclusive) of the range.
            level: Depth assigned to the node created for this range.

        Returns:
            The node covering the range, or None if the range is empty.
        """
        if i_ed < i_st or j_ed < j_st:
            return None

        # Pixel extent of the range, in multiples of the block size.
        i = i_st * self.min_blk_size
        j = j_st * self.min_blk_size
        h = (i_ed - i_st + 1) * self.min_blk_size
        w = (j_ed - j_st + 1) * self.min_blk_size

        if i_ed == i_st and j_ed == j_st:
            return QuadTreeNode(i, j, h, w, level, cls_info_table[i_st][j_st])

        i_mid = (i_ed + i_st) // 2
        j_mid = (j_ed + j_st) // 2

        root = QuadTreeNode(i, j, h, w, level)

        # Child order: top-left, top-right, bottom-left, bottom-right.
        quadrants = (
            (i_st, i_mid, j_st, j_mid),
            (i_st, i_mid, j_mid + 1, j_ed),
            (i_mid + 1, i_ed, j_st, j_mid),
            (i_mid + 1, i_ed, j_mid + 1, j_ed), )
        for k, (r_st, r_ed, c_st, c_ed) in enumerate(quadrants):
            root.children[k] = self._build_tree(cls_info_table, r_st, r_ed,
                                                c_st, c_ed, level + 1)

        bins_list = [
            node.cls_info for node in root.get_children()
            if node.cls_info is not None
        ]
        if len(bins_list) > 0:
            root.cls_info = reduce(merge_bins, bins_list)
        else:
            # All children are background: merge them into this node.
            root.reset_children()

        return root

    def get_nodes(self, tar_cls=None, max_level=None, include_bg=True):
        """Collect the descendants of the root in breadth-first order.

        The root itself is never part of the result. A child that is filtered
        out is not expanded either, so none of its descendants are returned.

        Args:
            tar_cls: If not None, skip nodes whose count for this class is 0.
            max_level: If not None, nodes at this level or deeper are not
                expanded.
            include_bg: If False, skip background nodes.

        Returns:
            A list of nodes; empty if the tree has not been built.
        """
        if self.root is None:
            # Tree not built yet: nothing to collect.
            return []
        nodes = []
        q = deque([self.root])
        while q:
            node = q.popleft()
            if max_level is not None and node.level >= max_level:
                continue
            for child in node.get_children():
                if not include_bg and child.is_bg_node:
                    continue
                if tar_cls is not None and child.get_cls_cnt(tar_cls) == 0:
                    continue
                nodes.append(child)
                q.append(child)
        return nodes

    def visualize_regions(self, im_path, save_path='./vis_quadtree.png'):
        """Draw the rectangle of every collected node on an image and save it.

        Args:
            im_path: Path of the image to draw on (decoded with
                `paddlers.transforms.decode_image`).
            save_path: Path of the output image.

        Returns:
            `save_path`.

        Raises:
            ValueError: If the decoded image has fewer than 3 bands or is
                neither 2-D nor 3-D.
        """
        im = paddlers.transforms.decode_image(im_path)
        if im.ndim == 2:
            im = np.stack([im] * 3, axis=2)
        elif im.ndim == 3:
            c = im.shape[2]
            if c < 3:
                raise ValueError(
                    "For multi-spectral images, the number of bands should not be less than 3."
                )
            else:
                # Take first three bands as R, G, and B
                im = im[..., :3]
        else:
            raise ValueError("Unrecognized data format.")
        nodes = self.get_nodes(include_bg=True)
        vis = np.ascontiguousarray(im)
        for node in nodes:
            i, j, h, w = node.coords
            vis = cv2.rectangle(vis, (j, i), (j + w, i + h), (0, 0, 255), 2)
        cv2.imwrite(save_path, vis)
        return save_path

    def print_tree(self, node=None, level=0):
        """Print the subtree rooted at `node`, indented by depth.

        Args:
            node: Node to start from; defaults to the root of the tree.
            level: Indentation (number of spaces) used for `node`.
        """
        if node is None:
            node = self.root
        if node is None:
            # Tree not built yet: nothing to print.
            return
        print(' ' * level + '-', node)
        for child in node.get_children():
            self.print_tree(child, level + 1)


def merge_bins(bins1, bins2):
    """Return the element-wise sum of two histograms of possibly different length.

    The shorter histogram is treated as if it were padded with zeros at the
    end. The padding uses the dtype of the shorter histogram, so integer
    counts stay integers (padding with float zeros would silently turn the
    result into a float array).

    Args:
        bins1: 1-D array-like of counts.
        bins2: 1-D array-like of counts.

    Returns:
        A new array whose length is that of the longer input. The inputs are
        not modified.
    """
    longer, shorter = np.asarray(bins1), np.asarray(bins2)
    if len(longer) < len(shorter):
        longer, shorter = shorter, longer
    padded = np.zeros(len(longer), dtype=shorter.dtype)
    padded[:len(shorter)] = shorter
    return longer + padded


@time_it
def extract_ms_patches(im_paths,
                       mask_path,
                       save_dir,
                       min_patch_size=256,
                       bg_class=0,
                       target_class=None,
                       max_level=None,
                       include_bg=False,
                       nonzero_ratio=None,
                       visualize=False):
    """Crop patches from images and their mask, guided by a quadtree.

    A quadtree is built over the single-band mask; every collected node that
    lies completely inside the raster is cropped from each image and from the
    mask with `gdal.Translate`. Patches of an image are written to
    `save_dir/<image file name without extension>/`, mask patches to
    `save_dir/mask/`. Files are named after the index of the node, so the
    numbering has gaps where nodes were skipped.

    Args:
        im_paths: Paths of the source images.
        mask_path: Path of the mask raster (must have exactly one band).
        save_dir: Root directory for the extracted patches.
        min_patch_size: Smallest block size used by the quadtree.
        bg_class: Index of the background class.
        target_class: If not None, only nodes containing this class are kept.
        max_level: If not None, maximum tree level to descend to.
        include_bg: Whether to keep nodes that contain only background.
        nonzero_ratio: If not None, a patch is skipped when, for any image,
            the fraction of non-zero values is below this threshold.
        visualize: Whether to save a visualization of the quadtree drawn on
            the first image.

    Raises:
        RuntimeError: If the mask (or, when `nonzero_ratio` is set, an image)
            cannot be opened by GDAL.
        ValueError: If the mask does not have exactly one band.
    """

    def _open_raster(path):
        # Without GDAL exceptions enabled, a failed open returns None; fail
        # here with a clear message instead of a later AttributeError.
        ds = gdal.Open(path)
        if ds is None:
            raise RuntimeError(f"Failed to open raster file: {path}")
        return ds

    def _save_patch(src_path, idx, i, j, h, w, subdir=None):
        # Crop the window (j, i, w, h) of `src_path` into its own file.
        src_path = osp.normpath(src_path)
        src_name, src_ext = osp.splitext(osp.basename(src_path))
        subdir = subdir if subdir is not None else src_name
        dst_dir = osp.join(save_dir, subdir)
        os.makedirs(dst_dir, exist_ok=True)
        dst_name = FMT.format(idx=idx, ext=src_ext)
        dst_path = osp.join(dst_dir, dst_name)
        gdal.Translate(dst_path, src_path, srcWin=(j, i, w, h))
        return dst_path

    if nonzero_ratio is not None:
        print(
            "`nonzero_ratio` is not None. More time will be consumed to filter out all-zero patches."
        )

    mask_ds = _open_raster(mask_path)
    quad_tree = QuadTree(min_blk_size=min_patch_size)
    if mask_ds.RasterCount != 1:
        raise ValueError("The mask image has more than 1 band.")
    print("Start building quad tree...")
    quad_tree.build_tree(mask_ds, bg_class)
    if visualize:
        print("Start drawing rectangles...")
        save_path = quad_tree.visualize_regions(im_paths[0])
        print(f"The visualization result is saved in {save_path} .")
    print("Quad tree has been built. Now start collecting nodes...")
    nodes = quad_tree.get_nodes(
        tar_cls=target_class, max_level=max_level, include_bg=include_bg)
    print("Nodes collected. Saving patches...")

    # Open the images once, rather than once per node, and only when the
    # non-zero filter actually needs to read pixel data.
    im_datasets = []
    if nonzero_ratio is not None:
        im_datasets = [_open_raster(src_path) for src_path in im_paths]

    mask_h, mask_w = mask_ds.RasterYSize, mask_ds.RasterXSize
    for idx, node in enumerate(tqdm(nodes)):
        i, j, h, w = node.coords
        real_h = min(h, mask_h - i)
        real_w = min(w, mask_w - j)
        if real_h < h or real_w < w:
            # Skip incomplete patches
            continue
        if nonzero_ratio is not None:
            too_sparse = False
            for im_ds in im_datasets:
                arr = im_ds.ReadAsArray(j, i, real_w, real_h)
                if np.count_nonzero(arr) / arr.size < nonzero_ratio:
                    too_sparse = True
                    break
            if too_sparse:
                continue
        for src_path in im_paths:
            _save_patch(src_path, idx, i, j, real_h, real_w)
        _save_patch(mask_path, idx, i, j, real_h, real_w, 'mask')


if __name__ == '__main__':
    # Command-line entry point: parse the options and run the extraction.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--im_paths",
        type=str,
        required=True,
        nargs='+',
        help="Path of images. Different images must have unique file names.")
    arg_parser.add_argument(
        "--mask_path", type=str, required=True, help="Path of mask.")
    arg_parser.add_argument(
        "--save_dir",
        type=str,
        default='output',
        help="Path to save the extracted patches.")
    arg_parser.add_argument(
        "--min_patch_size",
        type=int,
        default=256,
        help="Minimum patch size (height and width).")
    arg_parser.add_argument(
        "--bg_class",
        type=int,
        default=0,
        help="Index of the background category.")
    arg_parser.add_argument(
        "--target_class",
        type=int,
        default=None,
        help="Index of the category of interest.")
    arg_parser.add_argument(
        "--max_level",
        type=int,
        default=None,
        help="Maximum level of hierarchical patches.")
    arg_parser.add_argument(
        "--include_bg",
        action='store_true',
        help="Include patches that contains only background pixels.")
    arg_parser.add_argument(
        "--nonzero_ratio",
        type=float,
        default=None,
        help="Threshold for filtering out less informative patches.")
    arg_parser.add_argument(
        "--visualize", action='store_true', help="Visualize the quadtree.")
    opts = arg_parser.parse_args()

    # Arguments are passed positionally, in the order of the signature of
    # `extract_ms_patches`.
    extract_ms_patches(
        opts.im_paths,
        opts.mask_path,
        opts.save_dir,
        opts.min_patch_size,
        opts.bg_class,
        opts.target_class,
        opts.max_level,
        opts.include_bg,
        opts.nonzero_ratio,
        opts.visualize, )