From e3f1c5b7e0ee24eb2cd5ddf4526831c1133eeb5d Mon Sep 17 00:00:00 2001
From: Vishwaesh Rajiv <vrajiv@vishwaeshs-mbp.lan>
Date: Tue, 6 Aug 2019 22:58:26 -0400
Subject: [PATCH 1/2] PyQt5 port + dockerization

---
 docker/Dockerfile                      |  16 +
 docker/README.md                       |  28 ++
 docker/data/__init__.py                |   0
 docker/data/colorize_image.py          | 566 +++++++++++++++++++++++++
 docker/data/lab_gamut.py               |  93 ++++
 docker/ideepcolor_docker.py            |  87 ++++
 docker/install/docker_conda_install.sh |   9 +
 docker/install/docker_deps_install.sh  |  11 +
 docker/install/install_conda.sh        |   6 +
 docker/install/install_deps.sh         |   5 +
 docker/models/pytorch/fetch_model.sh   |   2 +
 docker/models/pytorch/model.py         | 177 ++++++++
 docker/ui_PyQt5/__init__.py            |   0
 docker/ui_PyQt5/gui_design.py          | 181 ++++++++
 docker/ui_PyQt5/gui_draw.py            | 378 +++++++++++++++++
 docker/ui_PyQt5/gui_gamut.py           | 106 +++++
 docker/ui_PyQt5/gui_palette.py         |  95 +++++
 docker/ui_PyQt5/gui_vis.py             |  98 +++++
 docker/ui_PyQt5/ui_control.py          | 192 +++++++++
 docker/ui_PyQt5/utils.py               | 108 +++++
 20 files changed, 2158 insertions(+)
 create mode 100644 docker/Dockerfile
 create mode 100644 docker/README.md
 create mode 100644 docker/data/__init__.py
 create mode 100644 docker/data/colorize_image.py
 create mode 100644 docker/data/lab_gamut.py
 create mode 100644 docker/ideepcolor_docker.py
 create mode 100755 docker/install/docker_conda_install.sh
 create mode 100755 docker/install/docker_deps_install.sh
 create mode 100644 docker/install/install_conda.sh
 create mode 100644 docker/install/install_deps.sh
 create mode 100644 docker/models/pytorch/fetch_model.sh
 create mode 100644 docker/models/pytorch/model.py
 create mode 100644 docker/ui_PyQt5/__init__.py
 create mode 100644 docker/ui_PyQt5/gui_design.py
 create mode 100644 docker/ui_PyQt5/gui_draw.py
 create mode 100644 docker/ui_PyQt5/gui_gamut.py
 create mode 100644 docker/ui_PyQt5/gui_palette.py
 create mode 100644 docker/ui_PyQt5/gui_vis.py
 create mode 100644 docker/ui_PyQt5/ui_control.py
 create mode 100644 docker/ui_PyQt5/utils.py

diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000..277488c
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,16 @@
+FROM continuumio/miniconda3
+
+COPY ./install/docker_conda_install.sh /app/install/
+WORKDIR /app/install/
+
+RUN conda create -n env python=3.6
+RUN echo "source activate env" > ~/.bashrc
+ENV PATH /opt/conda/envs/env/bin:$PATH
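+# prepend the env's bin dir so 'python' in later steps resolves to the conda env created above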
+
+RUN ./docker_conda_install.sh
+
+COPY . /app
+WORKDIR /app
+
+CMD ["python","ideepcolor_docker.py"]
diff --git a/docker/README.md b/docker/README.md
new file mode 100644
index 0000000..2135aa2
--- /dev/null
+++ b/docker/README.md
@@ -0,0 +1,28 @@
+# Using Docker
+
+### Note
+
+The following works on macOS; I have not tested it on other platforms.
+
+I've converted the PyQt4 code to PyQt5 so that the app runs easily under Python 3.
+
+Please follow this guide ([https://sourabhbajaj.com/blog/2017/02/07/gui-applications-docker-mac/](https://sourabhbajaj.com/blog/2017/02/07/gui-applications-docker-mac/)) to install and configure XQuartz (and Docker, if you haven't already).
+
+### Instructions
+Assuming you've cloned this repository and are currently in the `ideepcolor/` directory:
+
+    cd docker
+    bash models/pytorch/fetch_model.sh # Fetches the PyTorch model
+    docker build -t colorize .
+    docker image ls
+
+You should see a list of Docker images, one of them named `colorize`. 
+
+Then:
+
+    xhost + 127.0.0.1
+    docker run -e DISPLAY=host.docker.internal:0 colorize
+    
+in order to run the app!
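+
+If no window appears, check that XQuartz is running and that "Allow connections from network clients" is enabled in its Security preferences (the guide above covers this), then re-run the `xhost` command.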
diff --git a/docker/data/__init__.py b/docker/data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/docker/data/colorize_image.py b/docker/data/colorize_image.py
new file mode 100644
index 0000000..a456fcc
--- /dev/null
+++ b/docker/data/colorize_image.py
@@ -0,0 +1,566 @@
+import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+from skimage import color
+from sklearn.cluster import KMeans
+import os
+from scipy.ndimage.interpolation import zoom
+
+
+def create_temp_directory(path_template, N=1e8):
+    print(path_template)
+    cur_path = path_template % np.random.randint(0, N)
+    while(os.path.exists(cur_path)):
+        cur_path = path_template % np.random.randint(0, N)
+    print('Creating directory: %s' % cur_path)
+    os.mkdir(cur_path)
+    return cur_path
+
+
+def lab2rgb_transpose(img_l, img_ab):
+    ''' INPUTS
+            img_l     1xXxX     [0,100]
+            img_ab     2xXxX     [-100,100]
+        OUTPUTS
+            returned value is XxXx3 '''
+    pred_lab = np.concatenate((img_l, img_ab), axis=0).transpose((1, 2, 0))
+    pred_rgb = (np.clip(color.lab2rgb(pred_lab), 0, 1) * 255).astype('uint8')
+    return pred_rgb
+
+
+def rgb2lab_transpose(img_rgb):
+    ''' INPUTS
+            img_rgb XxXx3
+        OUTPUTS
+            returned value is 3xXxX '''
+    return color.rgb2lab(img_rgb).transpose((2, 0, 1))
+
+
+class ColorizeImageBase():
+    def __init__(self, Xd=256, Xfullres_max=10000):
+        self.Xd = Xd
+        self.img_l_set = False
+        self.net_set = False
+        self.Xfullres_max = Xfullres_max  # maximum size of maximum dimension
+        self.img_just_set = False  # this will be true whenever image is just loaded
+        # net_forward can set this back to False if needed
+
+    def prep_net(self):
+        raise Exception("Must be implemented by the subclass")
+
+    # ***** Image prepping *****
+    def load_image(self, input_path):
+        # rgb image [CxXdxXd]
+        im = cv2.cvtColor(cv2.imread(input_path, 1), cv2.COLOR_BGR2RGB)
+        self.img_rgb_fullres = im.copy()
+        self._set_img_lab_fullres_()
+
+        im = cv2.resize(im, (self.Xd, self.Xd))
+        self.img_rgb = im.copy()
+        # self.img_rgb = sp.misc.imresize(plt.imread(input_path),(self.Xd,self.Xd)).transpose((2,0,1))
+
+        self.img_l_set = True
+
+        # convert into lab space
+        self._set_img_lab_()
+        self._set_img_lab_mc_()
+
+    def set_image(self, input_image):
+        self.img_rgb_fullres = input_image.copy()
+        self._set_img_lab_fullres_()
+
+        self.img_l_set = True
+
+        self.img_rgb = input_image
+        # convert into lab space
+        self._set_img_lab_()
+        self._set_img_lab_mc_()
+
+    def net_forward(self, input_ab, input_mask):
+        # INPUTS
+        #     ab         2xXxX     input color patches (non-normalized)
+        #     mask     1xXxX    input mask, indicating which points have been provided
+        # assumes self.img_l_mc has been set
+
+        if(not self.img_l_set):
+            print('I need to have an image!')
+            return -1
+        if(not self.net_set):
+            print('I need to have a net!')
+            return -1
+
+        self.input_ab = input_ab
+        self.input_ab_mc = (input_ab - self.ab_mean) / self.ab_norm
+        self.input_mask = input_mask
+        self.input_mask_mult = input_mask * self.mask_mult
+        return 0
+
+    def get_result_PSNR(self, result=-1, return_SE_map=False):
+        if np.array((result)).flatten()[0] == -1:
+            cur_result = self.get_img_forward()
+        else:
+            cur_result = result.copy()
+        SE_map = (1. * self.img_rgb - cur_result)**2
+        cur_MSE = np.mean(SE_map)
+        cur_PSNR = 20 * np.log10(255. / np.sqrt(cur_MSE))
+        if return_SE_map:
+            return(cur_PSNR, SE_map)
+        else:
+            return cur_PSNR
+
+    def get_img_forward(self):
+        # get image with point estimate
+        return self.output_rgb
+
+    def get_img_gray(self):
+        # Get black and white image
+        return lab2rgb_transpose(self.img_l, np.zeros((2, self.Xd, self.Xd)))
+
+    def get_img_gray_fullres(self):
+        # Get black and white image
+        return lab2rgb_transpose(self.img_l_fullres, np.zeros((2, self.img_l_fullres.shape[1], self.img_l_fullres.shape[2])))
+
+    def get_img_fullres(self):
+        # This assumes self.img_l_fullres, self.output_ab are set.
+        # Typically, this means that set_image() and net_forward()
+        # have been called.
+        # bilinear upsample
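+        # zoom_factor is per axis (channels, height, width); channels stay unscaled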
+        zoom_factor = (1, 1. * self.img_l_fullres.shape[1] / self.output_ab.shape[1], 1. * self.img_l_fullres.shape[2] / self.output_ab.shape[2])
+        output_ab_fullres = zoom(self.output_ab, zoom_factor, order=1)
+
+        return lab2rgb_transpose(self.img_l_fullres, output_ab_fullres)
+
+    def get_input_img_fullres(self):
+        zoom_factor = (1, 1. * self.img_l_fullres.shape[1] / self.input_ab.shape[1], 1. * self.img_l_fullres.shape[2] / self.input_ab.shape[2])
+        input_ab_fullres = zoom(self.input_ab, zoom_factor, order=1)
+        return lab2rgb_transpose(self.img_l_fullres, input_ab_fullres)
+
+    def get_input_img(self):
+        return lab2rgb_transpose(self.img_l, self.input_ab)
+
+    def get_img_mask(self):
+        # Get black and white image
+        return lab2rgb_transpose(100. * (1 - self.input_mask), np.zeros((2, self.Xd, self.Xd)))
+
+    def get_img_mask_fullres(self):
+        # Get black and white image
+        zoom_factor = (1, 1. * self.img_l_fullres.shape[1] / self.input_ab.shape[1], 1. * self.img_l_fullres.shape[2] / self.input_ab.shape[2])
+        input_mask_fullres = zoom(self.input_mask, zoom_factor, order=0)
+        return lab2rgb_transpose(100. * (1 - input_mask_fullres), np.zeros((2, input_mask_fullres.shape[1], input_mask_fullres.shape[2])))
+
+    def get_sup_img(self):
+        return lab2rgb_transpose(50 * self.input_mask, self.input_ab)
+
+    def get_sup_fullres(self):
+        zoom_factor = (1, 1. * self.img_l_fullres.shape[1] / self.output_ab.shape[1], 1. * self.img_l_fullres.shape[2] / self.output_ab.shape[2])
+        input_mask_fullres = zoom(self.input_mask, zoom_factor, order=0)
+        input_ab_fullres = zoom(self.input_ab, zoom_factor, order=0)
+        return lab2rgb_transpose(50 * input_mask_fullres, input_ab_fullres)
+
+    # ***** Private functions *****
+    def _set_img_lab_fullres_(self):
+        # resize the full-resolution image so that its maximum dimension is at most Xfullres_max
+        Xfullres = self.img_rgb_fullres.shape[0]
+        Yfullres = self.img_rgb_fullres.shape[1]
+        if Xfullres > self.Xfullres_max or Yfullres > self.Xfullres_max:
+            if Xfullres > Yfullres:
+                zoom_factor = 1. * self.Xfullres_max / Xfullres
+            else:
+                zoom_factor = 1. * self.Xfullres_max / Yfullres
+            self.img_rgb_fullres = zoom(self.img_rgb_fullres, (zoom_factor, zoom_factor, 1), order=1)
+
+        self.img_lab_fullres = color.rgb2lab(self.img_rgb_fullres).transpose((2, 0, 1))
+        self.img_l_fullres = self.img_lab_fullres[[0], :, :]
+        self.img_ab_fullres = self.img_lab_fullres[1:, :, :]
+
+    def _set_img_lab_(self):
+        # set self.img_lab from self.im_rgb
+        self.img_lab = color.rgb2lab(self.img_rgb).transpose((2, 0, 1))
+        self.img_l = self.img_lab[[0], :, :]
+        self.img_ab = self.img_lab[1:, :, :]
+
+    def _set_img_lab_mc_(self):
+        # set self.img_lab_mc from self.img_lab
+        # lab image, mean centered [3xXxX]
+        self.img_lab_mc = self.img_lab / np.array((self.l_norm, self.ab_norm, self.ab_norm))[:, np.newaxis, np.newaxis] - np.array(
+            (self.l_mean / self.l_norm, self.ab_mean / self.ab_norm, self.ab_mean / self.ab_norm))[:, np.newaxis, np.newaxis]
+        self._set_img_l_()
+
+    def _set_img_l_(self):
+        self.img_l_mc = self.img_lab_mc[[0], :, :]
+        self.img_l_set = True
+
+    def _set_img_ab_(self):
+        self.img_ab_mc = self.img_lab_mc[[1, 2], :, :]
+
+    def _set_out_ab_(self):
+        self.output_lab = rgb2lab_transpose(self.output_rgb)
+        self.output_ab = self.output_lab[1:, :, :]
+
+
+class ColorizeImageTorch(ColorizeImageBase):
+    def __init__(self, Xd=256, maskcent=False):
+        print('ColorizeImageTorch instantiated')
+        ColorizeImageBase.__init__(self, Xd)
+        self.l_norm = 1.
+        self.ab_norm = 1.
+        self.l_mean = 50.
+        self.ab_mean = 0.
+        self.mask_mult = 1.
+        self.mask_cent = .5 if maskcent else 0
+
+        # Load grid properties
+        self.pts_in_hull = np.array(np.meshgrid(np.arange(-110, 120, 10), np.arange(-110, 120, 10))).reshape((2, 529)).T
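+        # (a 23x23 grid of ab bins spaced 10 apart, flattened to 529 (a, b) points)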
+
+    # ***** Net preparation *****
+    def prep_net(self, gpu_id=None, path='', dist=False):
+        import torch
+        import models.pytorch.model as model
+        print('path = %s' % path)
+        print('Model set! dist mode? ', dist)
+        self.net = model.SIGGRAPHGenerator(dist=dist)
+        state_dict = torch.load(path)
+        if hasattr(state_dict, '_metadata'):
+            del state_dict._metadata
+
+        # patch InstanceNorm checkpoints prior to 0.4
+        for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
+            self.__patch_instance_norm_state_dict(state_dict, self.net, key.split('.'))
+        self.net.load_state_dict(state_dict)
+        if gpu_id is not None:
+            self.net.cuda()
+        self.net.eval()
+        self.net_set = True
+
+    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+        key = keys[i]
+        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+                    (key == 'running_mean' or key == 'running_var'):
+                if getattr(module, key) is None:
+                    state_dict.pop('.'.join(keys))
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+               (key == 'num_batches_tracked'):
+                state_dict.pop('.'.join(keys))
+        else:
+            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+    # ***** Call forward *****
+    def net_forward(self, input_ab, input_mask):
+        # INPUTS
+        #     ab         2xXxX     input color patches (non-normalized)
+        #     mask     1xXxX    input mask, indicating which points have been provided
+        # assumes self.img_l_mc has been set
+
+        if ColorizeImageBase.net_forward(self, input_ab, input_mask) == -1:
+            return -1
+
+        # net_input_prepped = np.concatenate((self.img_l_mc, self.input_ab_mc, self.input_mask_mult), axis=0)
+
+        # return prediction
+        # self.net.blobs['data_l_ab_mask'].data[...] = net_input_prepped
+        # embed()
+        output_ab = self.net.forward(self.img_l_mc, self.input_ab_mc, self.input_mask_mult, self.mask_cent)[0, :, :, :].cpu().data.numpy()
+        self.output_rgb = lab2rgb_transpose(self.img_l, output_ab)
+        # self.output_rgb = lab2rgb_transpose(self.img_l, self.net.blobs[self.pred_ab_layer].data[0, :, :, :])
+
+        self._set_out_ab_()
+        return self.output_rgb
+
+    def get_img_forward(self):
+        # get image with point estimate
+        return self.output_rgb
+
+    def get_img_gray(self):
+        # Get black and white image
+        return lab2rgb_transpose(self.img_l, np.zeros((2, self.Xd, self.Xd)))
+
+
+class ColorizeImageTorchDist(ColorizeImageTorch):
+    def __init__(self, Xd=256, maskcent=False):
+        ColorizeImageTorch.__init__(self, Xd)
+        self.dist_ab_set = False
+        self.pts_grid = np.array(np.meshgrid(np.arange(-110, 120, 10), np.arange(-110, 120, 10))).reshape((2, 529)).T
+        self.in_hull = np.ones(529, dtype=bool)
+        self.AB = self.pts_grid.shape[0]  # 529
+        self.A = int(np.sqrt(self.AB))  # 23
+        self.B = int(np.sqrt(self.AB))  # 23
+        self.dist_ab_full = np.zeros((self.AB, self.Xd, self.Xd))
+        self.dist_ab_grid = np.zeros((self.A, self.B, self.Xd, self.Xd))
+        self.dist_entropy = np.zeros((self.Xd, self.Xd))
+        self.mask_cent = .5 if maskcent else 0
+
+    def prep_net(self, gpu_id=None, path='', dist=True, S=.2):
+        ColorizeImageTorch.prep_net(self, gpu_id=gpu_id, path=path, dist=dist)
+        # TODO: set S (the softmax temperature; currently unused here)
+
+    def net_forward(self, input_ab, input_mask):
+        # INPUTS
+        #     ab         2xXxX     input color patches (non-normalized)
+        #     mask     1xXxX    input mask, indicating which points have been provided
+        # assumes self.img_l_mc has been set
+
+        # embed()
+        if ColorizeImageBase.net_forward(self, input_ab, input_mask) == -1:
+            return -1
+
+        # set distribution
+        (function_return, self.dist_ab) = self.net.forward(self.img_l_mc, self.input_ab_mc, self.input_mask_mult, self.mask_cent)
+        function_return = function_return[0, :, :, :].cpu().data.numpy()
+        self.dist_ab = self.dist_ab[0, :, :, :].cpu().data.numpy()
+        self.dist_ab_set = True
+
+        # full grid, ABxXxX, AB = 529
+        self.dist_ab_full[self.in_hull, :, :] = self.dist_ab
+
+        # gridded, AxBxXxX, A = 23
+        self.dist_ab_grid = self.dist_ab_full.reshape((self.A, self.B, self.Xd, self.Xd))
+
+        # return
+        return function_return
+
+    def get_ab_reccs(self, h, w, K=5, N=25000, return_conf=False):
+        ''' Recommended colors at point (h,w)
+        Call this after calling net_forward
+        '''
+        if not self.dist_ab_set:
+            print('Need to set prediction first')
+            return 0
+
+        # randomly sample from pdf
+        cmf = np.cumsum(self.dist_ab[:, h, w])  # CMF
+        cmf = cmf / cmf[-1]
+        cmf_bins = cmf
+
+        # randomly sample N points
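+        # (inverse-CDF sampling: uniform draws mapped through the CMF select ab bins)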
+        rnd_pts = np.random.uniform(low=0, high=1.0, size=N)
+        inds = np.digitize(rnd_pts, bins=cmf_bins)
+        rnd_pts_ab = self.pts_in_hull[inds, :]
+
+        # run k-means
+        kmeans = KMeans(n_clusters=K).fit(rnd_pts_ab)
+
+        # sort by cluster occupancy
+        k_label_cnt = np.histogram(kmeans.labels_, np.arange(0, K + 1))[0]
+        k_inds = np.argsort(k_label_cnt, axis=0)[::-1]
+
+        cluster_per = 1. * k_label_cnt[k_inds] / N  # percentage of points within cluster
+        cluster_centers = kmeans.cluster_centers_[k_inds, :]  # cluster centers
+
+        # cluster_centers = np.random.uniform(low=-100,high=100,size=(N,2))
+        if return_conf:
+            return cluster_centers, cluster_per
+        else:
+            return cluster_centers
+
+    def compute_entropy(self):
+        # computes sum(p * log p), i.e. the negative entropy (really slow right now);
+        # plot_dist_entropy negates this to display the actual entropy
+        self.dist_entropy = np.sum(self.dist_ab * np.log(self.dist_ab), axis=0)
+
+    def plot_dist_grid(self, h, w):
+        # Plots distribution at a given point
+        plt.figure()
+        plt.imshow(self.dist_ab_grid[:, :, h, w], extent=[-110, 110, 110, -110], interpolation='nearest')
+        plt.colorbar()
+        plt.ylabel('a')
+        plt.xlabel('b')
+
+    def plot_dist_entropy(self):
+        # Plots the entropy map over the image
+        plt.figure()
+        plt.imshow(-self.dist_entropy, interpolation='nearest')
+        plt.colorbar()
+
+
+class ColorizeImageCaffe(ColorizeImageBase):
+    def __init__(self, Xd=256):
+        print('ColorizeImageCaffe instantiated')
+        ColorizeImageBase.__init__(self, Xd)
+        self.l_norm = 1.
+        self.ab_norm = 1.
+        self.l_mean = 50.
+        self.ab_mean = 0.
+        self.mask_mult = 110.
+
+        self.pred_ab_layer = 'pred_ab'  # predicted ab layer
+
+        # Load grid properties
+        self.pts_in_hull_path = './data/color_bins/pts_in_hull.npy'
+        self.pts_in_hull = np.load(self.pts_in_hull_path)  # 313x2, in-gamut
+
+    # ***** Net preparation *****
+    def prep_net(self, gpu_id, prototxt_path='', caffemodel_path=''):
+        import caffe
+        print('gpu_id = %d, net_path = %s, model_path = %s' % (gpu_id, prototxt_path, caffemodel_path))
+        if gpu_id == -1:
+            caffe.set_mode_cpu()
+        else:
+            caffe.set_device(gpu_id)
+            caffe.set_mode_gpu()
+        self.gpu_id = gpu_id
+        self.net = caffe.Net(prototxt_path, caffemodel_path, caffe.TEST)
+        self.net_set = True
+
+        # automatically set cluster centers
+        if len(self.net.params[self.pred_ab_layer][0].data[...].shape) == 4 and self.net.params[self.pred_ab_layer][0].data[...].shape[1] == 313:
+            print('Setting ab cluster centers in layer: %s' % self.pred_ab_layer)
+            self.net.params[self.pred_ab_layer][0].data[:, :, 0, 0] = self.pts_in_hull.T
+
+        # automatically set upsampling kernel
+        for layer in self.net._layer_names:
+            if layer[-3:] == '_us':
+                print('Setting upsampling layer kernel: %s' % layer)
+                self.net.params[layer][0].data[:, 0, :, :] = np.array(((.25, .5, .25, 0), (.5, 1., .5, 0), (.25, .5, .25, 0), (0, 0, 0, 0)))[np.newaxis, :, :]
+
+    # ***** Call forward *****
+    def net_forward(self, input_ab, input_mask):
+        # INPUTS
+        #     ab         2xXxX     input color patches (non-normalized)
+        #     mask     1xXxX    input mask, indicating which points have been provided
+        # assumes self.img_l_mc has been set
+
+        if ColorizeImageBase.net_forward(self, input_ab, input_mask) == -1:
+            return -1
+
+        net_input_prepped = np.concatenate((self.img_l_mc, self.input_ab_mc, self.input_mask_mult), axis=0)
+
+        self.net.blobs['data_l_ab_mask'].data[...] = net_input_prepped
+        self.net.forward()
+
+        # return prediction
+        self.output_rgb = lab2rgb_transpose(self.img_l, self.net.blobs[self.pred_ab_layer].data[0, :, :, :])
+
+        self._set_out_ab_()
+        return self.output_rgb
+
+    def get_img_forward(self):
+        # get image with point estimate
+        return self.output_rgb
+
+    def get_img_gray(self):
+        # Get black and white image
+        return lab2rgb_transpose(self.img_l, np.zeros((2, self.Xd, self.Xd)))
+
+
+class ColorizeImageCaffeGlobDist(ColorizeImageCaffe):
+    # Caffe colorization, with additional global histogram as input
+    def __init__(self, Xd=256):
+        ColorizeImageCaffe.__init__(self, Xd)
+        self.glob_mask_mult = 1.
+        self.glob_layer = 'glob_ab_313_mask'
+
+    def net_forward(self, input_ab, input_mask, glob_dist=-1):
+        # glob_dist is 313 array, or -1
+        if np.array(glob_dist).flatten()[0] == -1:  # run without this, zero it out
+            self.net.blobs[self.glob_layer].data[0, :-1, 0, 0] = 0.
+            self.net.blobs[self.glob_layer].data[0, -1, 0, 0] = 0.
+        else:  # run conditioned on global histogram
+            self.net.blobs[self.glob_layer].data[0, :-1, 0, 0] = glob_dist
+            self.net.blobs[self.glob_layer].data[0, -1, 0, 0] = self.glob_mask_mult
+
+        self.output_rgb = ColorizeImageCaffe.net_forward(self, input_ab, input_mask)
+        self._set_out_ab_()
+        return self.output_rgb
+
+
+class ColorizeImageCaffeDist(ColorizeImageCaffe):
+    # caffe model which includes distribution prediction
+    def __init__(self, Xd=256):
+        ColorizeImageCaffe.__init__(self, Xd)
+        self.dist_ab_set = False
+        self.scale_S_layer = 'scale_S'
+        self.dist_ab_S_layer = 'dist_ab_S'  # softened distribution layer
+        self.pts_grid = np.load('./data/color_bins/pts_grid.npy')  # 529x2, all points
+        self.in_hull = np.load('./data/color_bins/in_hull.npy')  # 529 bool
+        self.AB = self.pts_grid.shape[0]  # 529
+        self.A = int(np.sqrt(self.AB))  # 23
+        self.B = int(np.sqrt(self.AB))  # 23
+        self.dist_ab_full = np.zeros((self.AB, self.Xd, self.Xd))
+        self.dist_ab_grid = np.zeros((self.A, self.B, self.Xd, self.Xd))
+        self.dist_entropy = np.zeros((self.Xd, self.Xd))
+
+    def prep_net(self, gpu_id, prototxt_path='', caffemodel_path='', S=.2):
+        ColorizeImageCaffe.prep_net(self, gpu_id, prototxt_path=prototxt_path, caffemodel_path=caffemodel_path)
+        self.S = S
+        self.net.params[self.scale_S_layer][0].data[...] = S
+
+    def net_forward(self, input_ab, input_mask):
+        # INPUTS
+        #     ab         2xXxX     input color patches (non-normalized)
+        #     mask     1xXxX    input mask, indicating which points have been provided
+        # assumes self.img_l_mc has been set
+
+        function_return = ColorizeImageCaffe.net_forward(self, input_ab, input_mask)
+        if np.array(function_return).flatten()[0] == -1:  # errored out
+            return -1
+
+        # set distribution
+        # in-gamut, CxXxX, C = 313
+        self.dist_ab = self.net.blobs[self.dist_ab_S_layer].data[0, :, :, :]
+        self.dist_ab_set = True
+
+        # full grid, ABxXxX, AB = 529
+        self.dist_ab_full[self.in_hull, :, :] = self.dist_ab
+
+        # gridded, AxBxXxX, A = 23
+        self.dist_ab_grid = self.dist_ab_full.reshape((self.A, self.B, self.Xd, self.Xd))
+
+        # return
+        return function_return
+
+    def get_ab_reccs(self, h, w, K=5, N=25000, return_conf=False):
+        ''' Recommended colors at point (h,w)
+        Call this after calling net_forward
+        '''
+        if not self.dist_ab_set:
+            print('Need to set prediction first')
+            return 0
+
+        # randomly sample from pdf
+        cmf = np.cumsum(self.dist_ab[:, h, w])  # CMF
+        cmf = cmf / cmf[-1]
+        cmf_bins = cmf
+
+        # randomly sample N points
+        rnd_pts = np.random.uniform(low=0, high=1.0, size=N)
+        inds = np.digitize(rnd_pts, bins=cmf_bins)
+        rnd_pts_ab = self.pts_in_hull[inds, :]
+
+        # run k-means
+        kmeans = KMeans(n_clusters=K).fit(rnd_pts_ab)
+
+        # sort by cluster occupancy
+        k_label_cnt = np.histogram(kmeans.labels_, np.arange(0, K + 1))[0]
+        k_inds = np.argsort(k_label_cnt, axis=0)[::-1]
+
+        cluster_per = 1. * k_label_cnt[k_inds] / N  # percentage of points within cluster
+        cluster_centers = kmeans.cluster_centers_[k_inds, :]  # cluster centers
+
+        # cluster_centers = np.random.uniform(low=-100,high=100,size=(N,2))
+        if return_conf:
+            return cluster_centers, cluster_per
+        else:
+            return cluster_centers
+
+    def compute_entropy(self):
+        # computes sum(p * log p), i.e. the negative entropy (really slow right now);
+        # plot_dist_entropy negates this to display the actual entropy
+        self.dist_entropy = np.sum(self.dist_ab * np.log(self.dist_ab), axis=0)
+
+    def plot_dist_grid(self, h, w):
+        # Plots distribution at a given point
+        plt.figure()
+        plt.imshow(self.dist_ab_grid[:, :, h, w], extent=[-110, 110, 110, -110], interpolation='nearest')
+        plt.colorbar()
+        plt.ylabel('a')
+        plt.xlabel('b')
+
+    def plot_dist_entropy(self):
+        # Plots the entropy map over the image
+        plt.figure()
+        plt.imshow(-self.dist_entropy, interpolation='nearest')
+        plt.colorbar()
diff --git a/docker/data/lab_gamut.py b/docker/data/lab_gamut.py
new file mode 100644
index 0000000..64b216c
--- /dev/null
+++ b/docker/data/lab_gamut.py
@@ -0,0 +1,93 @@
+import numpy as np
+from skimage import color
+import warnings
+
+
+def qcolor2lab_1d(qc):
+    # take 1d numpy array and do color conversion
+    c = np.array([qc.red(), qc.green(), qc.blue()], np.uint8)
+    return rgb2lab_1d(c)
+
+
+def rgb2lab_1d(in_rgb):
+    # take 1d numpy array and do color conversion
+    # in_rgb: channel-last array of RGB values (e.g. shape (3,) or (1, 1, 3))
+    # return color.rgb2lab(in_rgb[np.newaxis, np.newaxis, :]).flatten()
+    return color.rgb2lab(in_rgb).flatten()
+
+
+def lab2rgb_1d(in_lab, clip=True, dtype='uint8'):
+    warnings.filterwarnings("ignore")
+    tmp_rgb = color.lab2rgb(in_lab[np.newaxis, np.newaxis, :]).flatten()
+    if clip:
+        tmp_rgb = np.clip(tmp_rgb, 0, 1)
+    if dtype == 'uint8':
+        tmp_rgb = np.round(tmp_rgb * 255).astype('uint8')
+    return tmp_rgb
+
+
+def snap_ab(input_l, input_rgb, return_type='rgb'):
+    ''' given an input lightness and RGB color, snap the color to an in-gamut (L, a, b) value
+    '''
+    T = 20
+    warnings.filterwarnings("ignore")
+    input_lab = rgb2lab_1d(np.array([[input_rgb]]))  # convert input to lab
+    conv_lab = input_lab.copy()  # keep ab from input
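+    # fixed-point iteration: clamp L, then round-trip through clipped RGB until Lab stabilizes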
+    for t in range(T):
+        conv_lab[0] = input_l  # overwrite the lightness with the requested L
+        old_lab = conv_lab
+        tmp_rgb = color.lab2rgb(conv_lab[np.newaxis, np.newaxis, :]).flatten()
+        tmp_rgb = np.clip(tmp_rgb, 0, 1)
+        conv_lab = color.rgb2lab(tmp_rgb[np.newaxis, np.newaxis, :]).flatten()
+        dif_lab = np.sum(np.abs(conv_lab - old_lab))
+        if dif_lab < 1:
+            break
+        # print(conv_lab)
+
+    conv_rgb_ingamut = lab2rgb_1d(conv_lab, clip=True, dtype='uint8')
+    if (return_type == 'rgb'):
+        return conv_rgb_ingamut
+
+    elif(return_type == 'lab'):
+        conv_lab_ingamut = rgb2lab_1d(conv_rgb_ingamut)
+        return conv_lab_ingamut
+
+
+class abGrid():
+    def __init__(self, gamut_size=110, D=1):
+        self.D = D
+        self.vals_b, self.vals_a = np.meshgrid(np.arange(-gamut_size, gamut_size + D, D),
+                                               np.arange(-gamut_size, gamut_size + D, D))
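+        # the grid covers a and b in [-gamut_size, gamut_size] at step D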
+        self.pts_full_grid = np.concatenate((self.vals_a[:, :, np.newaxis], self.vals_b[:, :, np.newaxis]), axis=2)
+        self.A = self.pts_full_grid.shape[0]
+        self.B = self.pts_full_grid.shape[1]
+        self.AB = self.A * self.B
+        self.gamut_size = gamut_size
+
+    def update_gamut(self, l_in):
+        warnings.filterwarnings("ignore")
+        thresh = 1.0
+        pts_lab = np.concatenate((l_in + np.zeros((self.A, self.B, 1)), self.pts_full_grid), axis=2)
+        self.pts_rgb = (255 * np.clip(color.lab2rgb(pts_lab), 0, 1)).astype('uint8')
+        pts_lab_back = color.rgb2lab(self.pts_rgb)
+        pts_lab_diff = np.linalg.norm(pts_lab - pts_lab_back, axis=2)
+
+        self.mask = pts_lab_diff < thresh
+        mask3 = np.tile(self.mask[..., np.newaxis], [1, 1, 3])
+        self.masked_rgb = self.pts_rgb.copy()
+        self.masked_rgb[np.invert(mask3)] = 255
+        return self.masked_rgb, self.mask
+
+    def ab2xy(self, a, b):
+        y = self.gamut_size + a
+        x = self.gamut_size + b
+        # print('ab2xy (%d, %d) -> (%d, %d)' % (a, b, x, y))
+        return x, y
+
+    def xy2ab(self, x, y):
+        a = y - self.gamut_size
+        b = x - self.gamut_size
+        # print('xy2ab (%d, %d) -> (%d, %d)' % (x, y, a, b))
+        return a, b
diff --git a/docker/ideepcolor_docker.py b/docker/ideepcolor_docker.py
new file mode 100644
index 0000000..ba6c31c
--- /dev/null
+++ b/docker/ideepcolor_docker.py
@@ -0,0 +1,87 @@
+from __future__ import print_function
+import sys
+from PyQt5.QtWidgets import *
+import argparse
+from PyQt5.QtWidgets import QApplication
+from PyQt5.QtGui import QIcon
+from PyQt5.QtCore import Qt
+from ui_PyQt5 import gui_design
+from data import colorize_image as CI
+
+sys.path.append('./caffe_files')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='iDeepColor: deep interactive colorization')
+    # basic parameters
+    parser.add_argument('--win_size', dest='win_size', help='the size of the main window', type=int, default=512)
+    parser.add_argument('--image_file', dest='image_file', help='input image', type=str, default='test_imgs/mortar_pestle.jpg')
+    parser.add_argument('--gpu', dest='gpu', help='gpu id', type=int, default=0)
+    parser.add_argument('--cpu_mode', dest='cpu_mode', help='do not use gpu', action='store_true')
+
+    # Caffe - Main colorization model
+    parser.add_argument('--color_prototxt', dest='color_prototxt', help='colorization caffe prototxt', type=str,
+                        default='./models/reference_model/deploy_nodist.prototxt')
+    parser.add_argument('--color_caffemodel', dest='color_caffemodel', help='colorization caffe prototxt', type=str,
+                        default='./models/reference_model/model.caffemodel')
+
+    # Caffe - Distribution prediction model
+    parser.add_argument('--dist_prototxt', dest='dist_prototxt', type=str, help='distribution net prototxt',
+                        default='./models/reference_model/deploy_nopred.prototxt')
+    parser.add_argument('--dist_caffemodel', dest='dist_caffemodel', type=str, help='distribution net caffemodel',
+                        default='./models/reference_model/model.caffemodel')
+
+    # PyTorch (same model used for both)
+    parser.add_argument('--color_model', dest='color_model', help='colorization model', type=str,
+                        default='./models/pytorch/caffemodel.pth')
+    parser.add_argument('--dist_model', dest='dist_model', help='colorization distribution prediction model', type=str,
+                        default='./models/pytorch/caffemodel.pth')
+
+    parser.add_argument('--backend', dest='backend', type=str, help='caffe or pytorch', default='pytorch')
+    parser.add_argument('--pytorch_maskcent', dest='pytorch_maskcent', help='need to center mask (activate for siggraph_pretrained but not for converted caffemodel)', action='store_true')
+
+    # ***** DEPRECATED *****
+    parser.add_argument('--load_size', dest='load_size', help='image size', type=int, default=256)
+
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    for arg in vars(args):
+        print('[%s] =' % arg, getattr(args, arg))
+
+    if args.cpu_mode:
+        args.gpu = -1
+
+    args.win_size = int(args.win_size / 4.0) * 4  # make sure the width of the image can be divided by 4
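+    # (rounds down to a multiple of 4, e.g. 510 -> 508)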
+
+    if args.backend == 'caffe':
+        # initialize the colorization model
+        colorModel = CI.ColorizeImageCaffe(Xd=args.load_size)
+        colorModel.prep_net(args.gpu, args.color_prototxt, args.color_caffemodel)
+
+        distModel = CI.ColorizeImageCaffeDist(Xd=args.load_size)
+        distModel.prep_net(args.gpu, args.dist_prototxt, args.dist_caffemodel)
+    elif args.backend == 'pytorch':
+        colorModel = CI.ColorizeImageTorch(Xd=args.load_size, maskcent=args.pytorch_maskcent)
+        colorModel.prep_net(path=args.color_model)
+
+        distModel = CI.ColorizeImageTorchDist(Xd=args.load_size, maskcent=args.pytorch_maskcent)
+        distModel.prep_net(path=args.dist_model, dist=True)
+    else:
+        sys.exit('backend type [%s] not found!' % args.backend)
+
+    # initialize application
+    app = QApplication(sys.argv)
+    window = gui_design.GUIDesign(color_model=colorModel, dist_model=distModel,
+                                  img_file=args.image_file, load_size=args.load_size, win_size=args.win_size)
+    #app.setStyleSheet(qdarkstyle.load_stylesheet(pyside=False))  # comment this if you do not like dark stylesheet
+    app.setWindowIcon(QIcon('imgs/logo.png'))  # load logo
+    window.setWindowTitle('iColor')
+    window.setWindowFlags(window.windowFlags() & ~Qt.WindowMaximizeButtonHint)   # fix the window size
+    window.show()
+    app.exec_()
diff --git a/docker/install/docker_conda_install.sh b/docker/install/docker_conda_install.sh
new file mode 100755
index 0000000..943ff36
--- /dev/null
+++ b/docker/install/docker_conda_install.sh
@@ -0,0 +1,9 @@
+apt update
+apt-get -y install libgl1-mesa-glx
+apt-get -y install libqt5x11extras5
+conda install -n env -c anaconda protobuf  ## protobuf
+conda install -n env -c anaconda scikit-learn ## scikit-learn
+conda install -n env -c anaconda scikit-image  ## scikit-image
+conda install -n env -c menpo opencv   ## opencv
+conda install -n env pyqt ## qt5
+conda install -n env -c pytorch pytorch torchvision cudatoolkit=9.0 
diff --git a/docker/install/docker_deps_install.sh b/docker/install/docker_deps_install.sh
new file mode 100755
index 0000000..70825ce
--- /dev/null
+++ b/docker/install/docker_deps_install.sh
@@ -0,0 +1,11 @@
+
+wget --quiet https://repo.anaconda.com/miniconda/Miniconda2-4.7.10-Linux-x86_64.sh -O ~/miniconda.sh && \
+    /bin/bash ~/miniconda.sh -b -p /opt/conda && \
+    rm ~/miniconda.sh && \
+    ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
+    echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
+    echo "conda activate base" >> ~/.bashrc
+
+# apt-get update
+# apt-get install python3-pip python-opencv 
+# pip install scikit-image scikit-learn qdarkstyle
diff --git a/docker/install/install_conda.sh b/docker/install/install_conda.sh
new file mode 100644
index 0000000..259b73c
--- /dev/null
+++ b/docker/install/install_conda.sh
@@ -0,0 +1,6 @@
+conda install -c anaconda protobuf  ## protobuf
+conda install -c anaconda scikit-learn ## scikit-learn
+conda install -c anaconda scikit-image  ## scikit-image
+conda install -c menpo opencv   ## opencv
+conda install pyqt=4.11 ## qt4
+conda install -c auto qdarkstyle  ## qdarkstyle
diff --git a/docker/install/install_deps.sh b/docker/install/install_deps.sh
new file mode 100644
index 0000000..eaf998f
--- /dev/null
+++ b/docker/install/install_deps.sh
@@ -0,0 +1,5 @@
+sudo pip install scikit-image
+sudo pip install scikit-learn
+sudo apt-get install python-opencv
+sudo apt-get install python-qt4
+sudo pip install qdarkstyle
diff --git a/docker/models/pytorch/fetch_model.sh b/docker/models/pytorch/fetch_model.sh
new file mode 100644
index 0000000..e299515
--- /dev/null
+++ b/docker/models/pytorch/fetch_model.sh
@@ -0,0 +1,2 @@
+# save beside this script so the app finds ./models/pytorch/caffemodel.pth
+curl -o "$(dirname "$0")/caffemodel.pth" http://colorization.eecs.berkeley.edu/siggraph/models/caffemodel.pth
diff --git a/docker/models/pytorch/model.py b/docker/models/pytorch/model.py
new file mode 100644
index 0000000..2c8cf0e
--- /dev/null
+++ b/docker/models/pytorch/model.py
@@ -0,0 +1,177 @@
+import torch
+import torch.nn as nn
+
+
+class SIGGRAPHGenerator(nn.Module):
+    def __init__(self, dist=False):
+        super(SIGGRAPHGenerator, self).__init__()
+        self.dist = dist
+        use_bias = True
+        norm_layer = nn.BatchNorm2d
+
+        # Conv1
+        model1 = [nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model1 += [nn.ReLU(True), ]
+        model1 += [nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model1 += [nn.ReLU(True), ]
+        model1 += [norm_layer(64), ]
+        # add a subsampling operation
+
+        # Conv2
+        model2 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model2 += [nn.ReLU(True), ]
+        model2 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model2 += [nn.ReLU(True), ]
+        model2 += [norm_layer(128), ]
+        # add a subsampling layer operation
+
+        # Conv3
+        model3 = [nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model3 += [nn.ReLU(True), ]
+        model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model3 += [nn.ReLU(True), ]
+        model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model3 += [nn.ReLU(True), ]
+        model3 += [norm_layer(256), ]
+        # add a subsampling layer operation
+
+        # Conv4
+        model4 = [nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model4 += [nn.ReLU(True), ]
+        model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model4 += [nn.ReLU(True), ]
+        model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model4 += [nn.ReLU(True), ]
+        model4 += [norm_layer(512), ]
+
+        # Conv5
+        model5 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
+        model5 += [nn.ReLU(True), ]
+        model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
+        model5 += [nn.ReLU(True), ]
+        model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
+        model5 += [nn.ReLU(True), ]
+        model5 += [norm_layer(512), ]
+
+        # Conv6
+        model6 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
+        model6 += [nn.ReLU(True), ]
+        model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
+        model6 += [nn.ReLU(True), ]
+        model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
+        model6 += [nn.ReLU(True), ]
+        model6 += [norm_layer(512), ]
+
+        # Conv7
+        model7 = [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model7 += [nn.ReLU(True), ]
+        model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model7 += [nn.ReLU(True), ]
+        model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model7 += [nn.ReLU(True), ]
+        model7 += [norm_layer(512), ]
+
+        # Conv8
+        model8up = [nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias)]
+        model3short8 = [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+
+        model8 = [nn.ReLU(True), ]
+        model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model8 += [nn.ReLU(True), ]
+        model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model8 += [nn.ReLU(True), ]
+        model8 += [norm_layer(256), ]
+
+        # Conv9
+        model9up = [nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), ]
+        model2short9 = [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        # add the two feature maps above
+
+        model9 = [nn.ReLU(True), ]
+        model9 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        model9 += [nn.ReLU(True), ]
+        model9 += [norm_layer(128), ]
+
+        # Conv10
+        model10up = [nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), ]
+        model1short10 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
+        # add the two feature maps above
+
+        model10 = [nn.ReLU(True), ]
+        model10 += [nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=use_bias), ]
+        model10 += [nn.LeakyReLU(negative_slope=.2), ]
+
+        # classification output
+        model_class = [nn.Conv2d(256, 529, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), ]
+
+        # regression output
+        model_out = [nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), ]
+        model_out += [nn.Tanh()]
+
+        self.model1 = nn.Sequential(*model1)
+        self.model2 = nn.Sequential(*model2)
+        self.model3 = nn.Sequential(*model3)
+        self.model4 = nn.Sequential(*model4)
+        self.model5 = nn.Sequential(*model5)
+        self.model6 = nn.Sequential(*model6)
+        self.model7 = nn.Sequential(*model7)
+        self.model8up = nn.Sequential(*model8up)
+        self.model8 = nn.Sequential(*model8)
+        self.model9up = nn.Sequential(*model9up)
+        self.model9 = nn.Sequential(*model9)
+        self.model10up = nn.Sequential(*model10up)
+        self.model10 = nn.Sequential(*model10)
+        self.model3short8 = nn.Sequential(*model3short8)
+        self.model2short9 = nn.Sequential(*model2short9)
+        self.model1short10 = nn.Sequential(*model1short10)
+
+        self.model_class = nn.Sequential(*model_class)
+        self.model_out = nn.Sequential(*model_out)
+
+        self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='nearest'), ])
+        self.softmax = nn.Sequential(*[nn.Softmax(dim=1), ])
+
+    def forward(self, input_A, input_B, mask_B, maskcent=0):
+        # input_A \in [-50,+50]
+        # input_B \in [-110, +110]
+        # mask_B \in [0, +1.0]
+
+        input_A = torch.Tensor(input_A)[None, :, :, :]
+        input_B = torch.Tensor(input_B)[None, :, :, :]
+        mask_B = torch.Tensor(mask_B)[None, :, :, :]
+        mask_B = mask_B - maskcent
+        
+        # input_A = torch.Tensor(input_A).cuda()[None, :, :, :]
+        # input_B = torch.Tensor(input_B).cuda()[None, :, :, :]
+        # mask_B = torch.Tensor(mask_B).cuda()[None, :, :, :]
+
+        conv1_2 = self.model1(torch.cat((input_A / 100., input_B / 110., mask_B), dim=1))
+        conv2_2 = self.model2(conv1_2[:, :, ::2, ::2])
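+        # the [:, :, ::2, ::2] striding is the 2x subsampling noted in __init__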
+        conv3_3 = self.model3(conv2_2[:, :, ::2, ::2])
+        conv4_3 = self.model4(conv3_3[:, :, ::2, ::2])
+        conv5_3 = self.model5(conv4_3)
+        conv6_3 = self.model6(conv5_3)
+        conv7_3 = self.model7(conv6_3)
+
+        conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
+        conv8_3 = self.model8(conv8_up)
+
+        if(self.dist):
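+            # scale logits by .2 to soften the softmax (matches S=.2 used by the caffe dist model)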
+            out_cl = self.upsample4(self.softmax(self.model_class(conv8_3) * .2))
+
+            conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
+            conv9_3 = self.model9(conv9_up)
+            conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
+            conv10_2 = self.model10(conv10_up)
+            out_reg = self.model_out(conv10_2) * 110
+
+            return (out_reg, out_cl)
+        else:
+            conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
+            conv9_3 = self.model9(conv9_up)
+            conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
+            conv10_2 = self.model10(conv10_up)
+            out_reg = self.model_out(conv10_2)
+            return out_reg * 110
diff --git a/docker/ui_PyQt5/__init__.py b/docker/ui_PyQt5/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/docker/ui_PyQt5/gui_design.py b/docker/ui_PyQt5/gui_design.py
new file mode 100644
index 0000000..35c2e54
--- /dev/null
+++ b/docker/ui_PyQt5/gui_design.py
@@ -0,0 +1,181 @@
+import PyQt5.QtCore as QtCore
+from PyQt5.QtCore import *
+from PyQt5.QtWidgets import *
+from PyQt5.QtGui import *
+from . import gui_draw
+from . import gui_vis
+from . import gui_gamut
+from . import gui_palette
+import time
+
+
+class GUIDesign(QWidget):
+    def __init__(self, color_model, dist_model=None, img_file=None, load_size=256,
+                 win_size=256, save_all=True):
+        # draw the layout
+        QWidget.__init__(self)
+        # main layout
+        mainLayout = QHBoxLayout()
+        self.setLayout(mainLayout)
+        # gamut layout
+        self.gamutWidget = gui_gamut.GUIGamut(gamut_size=160)
+        gamutLayout = self.AddWidget(self.gamutWidget, 'ab Color Gamut')
+        colorLayout = QVBoxLayout()
+
+        colorLayout.addLayout(gamutLayout)
+        mainLayout.addLayout(colorLayout)
+
+        # palette
+        self.customPalette = gui_palette.GUIPalette(grid_sz=(10, 1))
+        self.usedPalette = gui_palette.GUIPalette(grid_sz=(10, 1))
+        cpLayout = self.AddWidget(self.customPalette, 'Suggested colors')
+        colorLayout.addLayout(cpLayout)
+        upLayout = self.AddWidget(self.usedPalette, 'Recently used colors')
+        colorLayout.addLayout(upLayout)
+
+        self.colorPush = QPushButton()  # to visualize the selected color
+        self.colorPush.setFixedWidth(self.customPalette.width())
+        self.colorPush.setFixedHeight(25)
+        self.colorPush.setStyleSheet("background-color: grey")
+        colorPushLayout = self.AddWidget(self.colorPush, 'Color')
+        colorLayout.addLayout(colorPushLayout)
+        colorLayout.setAlignment(Qt.AlignTop)
+
+        # drawPad layout
+        drawPadLayout = QVBoxLayout()
+        mainLayout.addLayout(drawPadLayout)
+        self.drawWidget = gui_draw.GUIDraw(color_model, dist_model, load_size=load_size, win_size=win_size)
+        drawPadLayout = self.AddWidget(self.drawWidget, 'Drawing Pad')
+        mainLayout.addLayout(drawPadLayout)
+
+        drawPadMenu = QHBoxLayout()
+
+        self.bGray = QCheckBox("&Gray")
+        self.bGray.setToolTip('show gray-scale image')
+
+        self.bLoad = QPushButton('&Load')
+        self.bLoad.setToolTip('load an input image')
+        self.bSave = QPushButton("&Save")
+        self.bSave.setToolTip('Save the current result.')
+
+        drawPadMenu.addWidget(self.bGray)
+        drawPadMenu.addWidget(self.bLoad)
+        drawPadMenu.addWidget(self.bSave)
+
+        drawPadLayout.addLayout(drawPadMenu)
+        self.visWidget = gui_vis.GUI_VIS(win_size=win_size, scale=win_size / float(load_size))
+        visWidgetLayout = self.AddWidget(self.visWidget, 'Result')
+        mainLayout.addLayout(visWidgetLayout)
+
+        self.bRestart = QPushButton("&Restart")
+        self.bRestart.setToolTip('Restart the system')
+
+        self.bQuit = QPushButton("&Quit")
+        self.bQuit.setToolTip('Quit the system.')
+        visWidgetMenu = QHBoxLayout()
+        visWidgetMenu.addWidget(self.bRestart)
+
+        visWidgetMenu.addWidget(self.bQuit)
+        visWidgetLayout.addLayout(visWidgetMenu)
+
+        self.drawWidget.update()
+        self.visWidget.update()
+        self.colorPush.clicked.connect(self.drawWidget.change_color)
+        # color indicator
+        self.drawWidget.update_color.connect(self.colorPush.setStyleSheet)
+        # update result
+        self.drawWidget.update_result.connect(self.visWidget.update_result)
+        # when drawWidget emits update_result, visWidget repaints with the new result
+        # (visWidget's update_result slot simply repaints the widget);
+        # the same signal also updates the gamut via gamutWidget.set_ab below.
+        self.drawWidget.update_result.connect(self.gamutWidget.set_ab)
+        # self.visWidget.boom.connect(self.gamutWidget.set_ab)
+        #self.drawWidget.update_result.connect(self.drawWidget.set_color)
+        self.visWidget.update_color.connect(self.colorPush.setStyleSheet)
+        # self.visWidget.update_color.connect(self.drawWidget.set_color)
+        # update gamut
+        self.drawWidget.update_gamut.connect(self.gamutWidget.set_gamut)
+        self.drawWidget.update_ab.connect(self.gamutWidget.set_ab)
+        self.gamutWidget.update_color.connect(self.drawWidget.set_color)
+        # connect palette
+        self.drawWidget.suggest_colors.connect(self.customPalette.set_colors)
+        self.customPalette.update_color.connect(self.drawWidget.set_color)
+        self.customPalette.update_color.connect(self.gamutWidget.set_ab)
+
+        self.drawWidget.used_colors.connect(self.usedPalette.set_colors)
+        self.usedPalette.update_color.connect(self.drawWidget.set_color)
+        self.usedPalette.update_color.connect(self.gamutWidget.set_ab)
+        # menu events
+        self.bGray.setChecked(True)
+        self.bRestart.clicked.connect(self.reset)
+        self.bQuit.clicked.connect(self.quit)
+        self.bGray.toggled.connect(self.enable_gray)
+        self.bSave.clicked.connect(self.save)
+        self.bLoad.clicked.connect(self.load)
+
+        self.start_t = time.time()
+
+        if img_file is not None:
+            self.drawWidget.init_result(img_file)
+
+    def AddWidget(self, widget, title):
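+        # wrap the widget in a titled QGroupBox and return a layout containing it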
+        widgetLayout = QVBoxLayout()
+        widgetBox = QGroupBox()
+        widgetBox.setTitle(title)
+        vbox_t = QVBoxLayout()
+        vbox_t.addWidget(widget)
+        widgetBox.setLayout(vbox_t)
+        widgetLayout.addWidget(widgetBox)
+
+        return widgetLayout
+
+    def nextImage(self):
+        self.drawWidget.nextImage()
+
+    def reset(self):
+        # self.start_t = time.time()
+        print('============================reset all=========================================')
+        self.visWidget.reset()
+        self.gamutWidget.reset()
+        self.customPalette.reset()
+        self.usedPalette.reset()
+        self.drawWidget.reset()
+        self.update()
+        self.colorPush.setStyleSheet("background-color: grey")
+
+    def enable_gray(self):
+        self.drawWidget.enable_gray()
+
+    def quit(self):
+        print('time spent = %3.3f' % (time.time() - self.start_t))
+        self.close()
+
+    def save(self):
+        print('time spent = %3.3f' % (time.time() - self.start_t))
+        self.drawWidget.save_result()
+
+    def load(self):
+        self.drawWidget.load_image()
+
+    def change_color(self):
+        print('change color')
+        self.drawWidget.change_color(use_suggest=True)
+
+    def keyPressEvent(self, event):
+        if event.key() == Qt.Key_R:
+            self.reset()
+
+        if event.key() == Qt.Key_Q:
+            self.save()
+            self.quit()
+
+        if event.key() == Qt.Key_S:
+            self.save()
+
+        if event.key() == Qt.Key_G:
+            self.bGray.toggle()
+
+        if event.key() == Qt.Key_L:
+            self.load()
diff --git a/docker/ui_PyQt5/gui_draw.py b/docker/ui_PyQt5/gui_draw.py
new file mode 100644
index 0000000..b26822c
--- /dev/null
+++ b/docker/ui_PyQt5/gui_draw.py
@@ -0,0 +1,379 @@
+import numpy as np
+import pdb
+import cv2
+from PyQt5.QtWidgets import *
+from PyQt5.QtCore import *
+from PyQt5.QtGui import *
+try:
+    QString = unicode
+except NameError:
+    # Python 3
+    QString = str
+
+from .ui_control import UIControl
+
+from data import lab_gamut
+from skimage import color
+import os
+import datetime
+import glob
+import sys
+
+
+class GUIDraw(QWidget):
+    update_color = pyqtSignal(QString)
+    update_gamut = pyqtSignal(np.float64)
+    suggest_colors = pyqtSignal(np.ndarray)
+    used_colors = pyqtSignal(np.ndarray)
+    update_ab = pyqtSignal(np.ndarray)
+    update_result = pyqtSignal(np.ndarray)
+
+    def __init__(self, model, dist_model=None, load_size=256, win_size=512):
+        QWidget.__init__(self)
+        self.model = model
+        self.image_file = None
+        self.pos = None
+        self.dist_model = dist_model  # distribution predictor, could be empty
+        self.win_size = win_size
+        self.load_size = load_size
+        self.setFixedSize(win_size, win_size)
+        self.uiControl = UIControl(win_size=win_size, load_size=load_size)
+        self.move(win_size, win_size)
+        self.movie = True
+        self.init_color()  # initialize color
+        self.im_gray3 = None
+        self.eraseMode = False
+        self.ui_mode = 'none'   # stroke or point
+        self.image_loaded = False
+        self.use_gray = True
+        self.total_images = 0
+        self.image_id = 0
+        self.method = 'with_dist'
+
+    def init_result(self, image_file):
+        self.read_image(image_file.encode('utf-8'))  # read an image
+        self.reset()
+
+    def get_batches(self, img_dir):
+        self.img_list = glob.glob(os.path.join(img_dir, '*.JPEG'))
+        self.total_images = len(self.img_list)
+        img_first = self.img_list[0]
+        self.init_result(img_first)
+
+    def nextImage(self):
+        self.save_result()
+        self.image_id += 1
+        if self.image_id == self.total_images:
+            print('you have finished all the results')
+            sys.exit()
+        img_current = self.img_list[self.image_id]
+        # self.reset()
+        self.init_result(img_current)
+
+    def read_image(self, image_file):
+        # self.result = None
+        self.image_loaded = True
+        self.image_file = image_file
+        print(image_file)
+        image_file = image_file.decode('utf8')
+        im_bgr = cv2.imread(image_file)
+        self.im_full = im_bgr.copy()
+        # get image for display
+        h, w, c = self.im_full.shape
+        max_width = max(h, w)
+        r = self.win_size / float(max_width)
+        self.scale = float(self.win_size) / self.load_size
+        print('scale = %f' % self.scale)
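+        # round the display width/height to multiples of 4, presumably to keep QImage rows aligned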
+        rw = int(round(r * w / 4.0) * 4)
+        rh = int(round(r * h / 4.0) * 4)
+
+        self.im_win = cv2.resize(self.im_full, (rw, rh), interpolation=cv2.INTER_CUBIC)
+
+        self.dw = int((self.win_size - rw) // 2)
+        self.dh = int((self.win_size - rh) // 2)
+        self.win_w = rw
+        self.win_h = rh
+        self.uiControl.setImageSize((rw, rh))
+        im_gray = cv2.cvtColor(im_bgr, cv2.COLOR_BGR2GRAY)
+        self.im_gray3 = cv2.cvtColor(im_gray, cv2.COLOR_GRAY2BGR)
+
+        self.gray_win = cv2.resize(self.im_gray3, (rw, rh), interpolation=cv2.INTER_CUBIC)
+        im_bgr = cv2.resize(im_bgr, (self.load_size, self.load_size), interpolation=cv2.INTER_CUBIC)
+        self.im_rgb = cv2.cvtColor(im_bgr, cv2.COLOR_BGR2RGB)
+        lab_win = color.rgb2lab(self.im_win[:, :, ::-1])
+
+        self.im_lab = color.rgb2lab(im_bgr[:, :, ::-1])
+        self.im_l = self.im_lab[:, :, 0]
+        self.l_win = lab_win[:, :, 0]
+        self.im_ab = self.im_lab[:, :, 1:]
+        self.im_size = self.im_rgb.shape[0:2]
+
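+        # user ab hints and hint mask start empty, at the network's input resolution (load_size)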
+        self.im_ab0 = np.zeros((2, self.load_size, self.load_size))
+        self.im_mask0 = np.zeros((1, self.load_size, self.load_size))
+        self.brushWidth = 2 * self.scale
+
+        self.model.load_image(image_file)
+
+        if self.dist_model is not None:
+            self.dist_model.set_image(self.im_rgb)
+            self.predict_color()
+
+    def update_im(self):
+        self.update()
+        QApplication.processEvents()
+
+    def update_ui(self, move_point=True):
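+        # apply the pending user edit (point / stroke / erase); return True when the model should re-predict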
+        if self.ui_mode == 'none':
+            return False
+        is_predict = False
+        snap_qcolor = self.calibrate_color(self.user_color, self.pos)
+        self.color = snap_qcolor
+        self.update_color.emit(QString('background-color: %s' % self.color.name()))
+
+        if self.ui_mode == 'point':
+            if move_point:
+                self.uiControl.movePoint(self.pos, snap_qcolor, self.user_color, self.brushWidth)
+            else:
+                self.user_color, self.brushWidth, isNew = self.uiControl.addPoint(self.pos, snap_qcolor, self.user_color, self.brushWidth)
+                if isNew:
+                    is_predict = True
+                    # self.predict_color()
+
+        if self.ui_mode == 'stroke':
+            self.uiControl.addStroke(self.prev_pos, self.pos, snap_qcolor, self.user_color, self.brushWidth)
+        if self.ui_mode == 'erase':
+            isRemoved = self.uiControl.erasePoint(self.pos)
+            if isRemoved:
+                is_predict = True
+                # self.predict_color()
+        return is_predict
+
+    def reset(self):
+        self.ui_mode = 'none'
+        self.pos = None
+        self.result = None
+        self.user_color = None
+        self.color = None
+        self.uiControl.reset()
+        self.init_color()
+        self.compute_result()
+        self.predict_color()
+        self.update()
+
+    def scale_point(self, pnt):
+        x = int((pnt.x() - self.dw) / float(self.win_w) * self.load_size)
+        y = int((pnt.y() - self.dh) / float(self.win_h) * self.load_size)
+        return x, y
+
+    def valid_point(self, pnt):
+        if pnt is None:
+            print('WARNING: no point\n')
+            return None
+        else:
+            if pnt.x() >= self.dw and pnt.y() >= self.dh and pnt.x() < self.win_size - self.dw and pnt.y() < self.win_size - self.dh:
+                x = int(np.round(pnt.x()))
+                y = int(np.round(pnt.y()))
+                return QPoint(x, y)
+            else:
+                print('WARNING: invalid point (%d, %d)\n' % (pnt.x(), pnt.y()))
+                return None
+
+    def init_color(self):
+        self.user_color = QColor(128, 128, 128)  # default color: gray
+        self.color = self.user_color
+
+    def change_color(self, pos=None):
+        if pos is not None:
+            x, y = self.scale_point(pos)
+            L = self.im_lab[y, x, 0]
+            self.update_gamut.emit(L)
+            rgb_colors = self.suggest_color(h=y, w=x, K=9)
+            rgb_colors[-1, :] = 0.5
+
+            self.suggest_colors.emit(rgb_colors)
+            used_colors = self.uiControl.used_colors()
+            if used_colors is not None:
+                self.used_colors.emit(used_colors)
+            snap_color = self.calibrate_color(self.user_color, pos)
+            c = np.array((snap_color.red(), snap_color.green(), snap_color.blue()), np.uint8)
+
+            self.update_ab.emit(c)
+
+    def calibrate_color(self, c, pos):
+        x, y = self.scale_point(pos)
+
+        # snap color based on L color
+        color_array = np.array((c.red(), c.green(), c.blue())).astype(
+            'uint8')
+        mean_L = self.im_l[y, x]
+        snap_color = lab_gamut.snap_ab(mean_L, color_array)
+        snap_qcolor = QColor(snap_color[0], snap_color[1], snap_color[2])
+        return snap_qcolor
+
+    def set_pos(self, pos):
+        self.pos = pos
+
+    def set_color(self, c_rgb):
+        c = QColor(c_rgb[0], c_rgb[1], c_rgb[2])
+        self.user_color = c
+        snap_qcolor = self.calibrate_color(c, self.pos)
+        self.color = snap_qcolor
+        self.update_color.emit(QString('background-color: %s' % self.color.name()))
+        self.uiControl.update_color(snap_qcolor, self.user_color)
+        self.compute_result()
+
+    def erase(self):
+        self.eraseMode = not self.eraseMode
+
+    def load_image(self):
+        # PyQt5's getOpenFileName returns a (path, selected_filter) tuple
+        img_path, _ = QFileDialog.getOpenFileName(self, 'load an input image')
+        if img_path:
+            self.init_result(img_path)
+
+    def save_result(self):
+        # self.image_file is stored as utf-8 bytes; decode it once before building paths
+        path = os.path.abspath(self.image_file.decode('utf8'))
+        path, ext = os.path.splitext(path)
+        suffix = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
+        save_path = "_".join([path, self.method, suffix])
+
+        print('saving result to <%s>\n' % save_path)
+        if not os.path.exists(save_path):
+            os.mkdir(save_path)
+
+        np.save(os.path.join(save_path, 'im_l.npy'), self.model.img_l)
+        np.save(os.path.join(save_path, 'im_ab.npy'), self.im_ab0)
+        np.save(os.path.join(save_path, 'im_mask.npy'), self.im_mask0)
+
+        result_bgr = cv2.cvtColor(self.result, cv2.COLOR_RGB2BGR)
+        mask = self.im_mask0.transpose((1, 2, 0)).astype(np.uint8) * 255
+        cv2.imwrite(os.path.join(save_path, 'input_mask.png'), mask)
+        cv2.imwrite(os.path.join(save_path, 'ours.png'), result_bgr)
+        cv2.imwrite(os.path.join(save_path, 'ours_fullres.png'), self.model.get_img_fullres()[:, :, ::-1])
+        cv2.imwrite(os.path.join(save_path, 'input_fullres.png'), self.model.get_input_img_fullres()[:, :, ::-1])
+        cv2.imwrite(os.path.join(save_path, 'input.png'), self.model.get_input_img()[:, :, ::-1])
+        cv2.imwrite(os.path.join(save_path, 'input_ab.png'), self.model.get_sup_img()[:, :, ::-1])
+
+    def enable_gray(self):
+        self.use_gray = not self.use_gray
+        self.update()
+
+    def predict_color(self):
+        if self.dist_model is not None and self.image_loaded:
+            im, mask = self.uiControl.get_input()
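+            # pack the user hints as channel-first (C, H, W) arrays: binary mask and ab channels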
+            im_mask0 = mask > 0.0
+            self.im_mask0 = im_mask0.transpose((2, 0, 1))
+            im_lab = color.rgb2lab(im).transpose((2, 0, 1))
+            self.im_ab0 = im_lab[1:3, :, :]
+
+            self.dist_model.net_forward(self.im_ab0, self.im_mask0)
+
+    def suggest_color(self, h, w, K=5):
+        if self.dist_model is not None and self.image_loaded:
+            ab, conf = self.dist_model.get_ab_reccs(h=h, w=w, K=K, N=25000, return_conf=True)
+            L = np.tile(self.im_lab[h, w, 0], (K, 1))
+            colors_lab = np.concatenate((L, ab), axis=1)
+            colors_lab3 = colors_lab[:, np.newaxis, :]
+            colors_rgb = np.clip(np.squeeze(color.lab2rgb(colors_lab3)), 0, 1)
+            colors_rgb_withcurr = np.concatenate((self.model.get_img_forward()[h, w, np.newaxis, :] / 255., colors_rgb), axis=0)
+            return colors_rgb_withcurr
+        else:
+            return None
+
+    def compute_result(self):
+        im, mask = self.uiControl.get_input()
+        im_mask0 = mask > 0.0
+        self.im_mask0 = im_mask0.transpose((2, 0, 1))
+        im_lab = color.rgb2lab(im).transpose((2, 0, 1))
+        self.im_ab0 = im_lab[1:3, :, :]
+
+        self.model.net_forward(self.im_ab0, self.im_mask0)
+        ab = self.model.output_ab.transpose((1, 2, 0))
+        ab_win = cv2.resize(ab, (self.win_w, self.win_h), interpolation=cv2.INTER_CUBIC)
+        pred_lab = np.concatenate((self.l_win[..., np.newaxis], ab_win), axis=2)
+        pred_rgb = (np.clip(color.lab2rgb(pred_lab), 0, 1) * 255).astype('uint8')
+        self.result = pred_rgb
+        # self.result is an H x W x 3 uint8 RGB image sized to the display window
+        self.update_result.emit(self.result)
+        self.update()
+
+    def paintEvent(self, event):
+        painter = QPainter()
+        painter.begin(self)
+        painter.fillRect(event.rect(), QColor(49, 54, 49))
+        painter.setRenderHint(QPainter.Antialiasing)
+        if self.use_gray or self.result is None:
+            im = self.gray_win
+        else:
+            im = self.result
+
+        if im is not None:
+            qImg = QImage(im.tostring(), im.shape[1], im.shape[0], QImage.Format_RGB888)
+            painter.drawImage(self.dw, self.dh, qImg)
+
+        self.uiControl.update_painter(painter)
+        painter.end()
+
+    def wheelEvent(self, event):
+        d = event.angleDelta().y() / 120
+        self.brushWidth = min(4.05 * self.scale, max(0, self.brushWidth + d * self.scale))
+        print('update brushWidth = %f' % self.brushWidth)
+        self.update_ui(move_point=True)
+        self.update()
+
+    def is_same_point(self, pos1, pos2):
+        if pos1 is None or pos2 is None:
+            return False
+        dx = pos1.x() - pos2.x()
+        dy = pos1.y() - pos2.y()
+        d = dx * dx + dy * dy
+        # print('distance between points = %f' % d)
+        return d < 25
+
+    def mousePressEvent(self, event):
+        print('mouse press', event.pos())
+        pos = self.valid_point(event.pos())
+
+        if pos is not None:
+            if event.button() == Qt.LeftButton:
+                self.pos = pos
+                self.ui_mode = 'point'
+                self.change_color(pos)
+                self.update_ui(move_point=False)
+                self.compute_result()
+
+            if event.button() == Qt.RightButton:
+                # erase the edit under the cursor
+                self.pos = pos
+                self.ui_mode = 'erase'
+                self.update_ui(move_point=False)
+                self.compute_result()
+
+    def mouseMoveEvent(self, event):
+        self.pos = self.valid_point(event.pos())
+        if self.pos is not None:
+            if self.ui_mode == 'point':
+                self.update_ui(move_point=True)
+                self.compute_result()
+
+    def mouseReleaseEvent(self, event):
+        pass
+
+    def sizeHint(self):
+        return QSize(self.win_size, self.win_size)
diff --git a/docker/ui_PyQt5/gui_gamut.py b/docker/ui_PyQt5/gui_gamut.py
new file mode 100644
index 0000000..a03b183
--- /dev/null
+++ b/docker/ui_PyQt5/gui_gamut.py
@@ -0,0 +1,106 @@
+import cv2
+import pdb
+from PyQt5.QtWidgets import *
+from PyQt5.QtCore import *
+from PyQt5.QtGui import *
+from data import lab_gamut
+import numpy as np
+
+
+class GUIGamut(QWidget):
+    update_color = pyqtSignal(np.ndarray)
+
+    def __init__(self, gamut_size=110):
+        QWidget.__init__(self)
+        self.gamut_size = gamut_size
+        self.win_size = gamut_size * 2  # display window is twice the gamut grid resolution
+        self.setFixedSize(self.win_size, self.win_size)
+        self.ab_grid = lab_gamut.abGrid(gamut_size=gamut_size, D=1)
+        self.reset()
+
+    def set_gamut(self, l_in=50):
+        self.l_in = l_in
+        self.ab_map, self.mask = self.ab_grid.update_gamut(l_in=l_in)
+        self.update()
+
+    def set_ab(self, color):
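+        # move the gamut cursor to the ab coordinates of the given RGB color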
+        self.color = color
+        if len(self.color.shape) == 1:
+            self.lab = lab_gamut.rgb2lab_1d((np.array([[self.color]])))
+        else:
+            self.lab = lab_gamut.rgb2lab_1d(np.squeeze(np.array(self.color)))
+        x, y = self.ab_grid.ab2xy(self.lab[1], self.lab[2])
+        self.pos = QPointF(x, y)
+        self.update()
+
+    def is_valid_point(self, pos):
+        if pos is None:
+            return False
+        else:
+            x = pos.x()
+            y = pos.y()
+            if x >= 0 and y >= 0 and x < self.win_size and y < self.win_size:
+                return self.mask[y, x]
+            else:
+                return False
+
+    def update_ui(self, pos):
+        self.pos = pos
+        a, b = self.ab_grid.xy2ab(pos.x(), pos.y())
+        # combine the current luminance L with the selected (a, b)
+        L = self.l_in
+        lab = np.array([L, a, b])
+        color = lab_gamut.lab2rgb_1d(lab, clip=True, dtype='uint8')
+        self.update_color.emit(color)
+        self.update()
+
+    def paintEvent(self, event):
+        painter = QPainter()
+        painter.begin(self)
+        painter.setRenderHint(QPainter.Antialiasing)
+        painter.fillRect(event.rect(), Qt.white)
+        if self.ab_map is not None:
+            ab_map = cv2.resize(self.ab_map, (self.win_size, self.win_size))
+            qImg = QImage(ab_map.tostring(), self.win_size, self.win_size, QImage.Format_RGB888)
+            painter.drawImage(0, 0, qImg)
+
+        painter.setPen(QPen(Qt.gray, 3, Qt.DotLine, cap=Qt.RoundCap, join=Qt.RoundJoin))
+        painter.drawLine(self.win_size // 2, 0, self.win_size // 2, self.win_size)
+        painter.drawLine(0, self.win_size // 2, self.win_size, self.win_size // 2)
+        if self.pos is not None:
+            painter.setPen(QPen(Qt.black, 2, Qt.SolidLine, cap=Qt.RoundCap, join=Qt.RoundJoin))
+            w = 5
+            x = self.pos.x()
+            y = self.pos.y()
+            painter.drawLine(x - w, y, x + w, y)
+            painter.drawLine(x, y - w, x, y + w)
+        painter.end()
+
+    def mousePressEvent(self, event):
+        pos = event.pos()
+
+        if event.button() == Qt.LeftButton and self.is_valid_point(pos):  # click the point
+            self.update_ui(pos)
+            self.mouseClicked = True
+
+    def mouseMoveEvent(self, event):
+        pos = event.pos()
+        if self.is_valid_point(pos):
+            if self.mouseClicked:
+                self.update_ui(pos)
+
+    def mouseReleaseEvent(self, event):
+        self.mouseClicked = False
+
+    def sizeHint(self):
+        return QSize(self.win_size, self.win_size)
+
+    def reset(self):
+        self.ab_map = None
+        self.mask = None
+        self.color = None
+        self.lab = None
+        self.pos = None
+        self.mouseClicked = False
+        self.update()
diff --git a/docker/ui_PyQt5/gui_palette.py b/docker/ui_PyQt5/gui_palette.py
new file mode 100644
index 0000000..58f9a13
--- /dev/null
+++ b/docker/ui_PyQt5/gui_palette.py
@@ -0,0 +1,95 @@
+import pdb
+from PyQt5.QtCore import *
+from PyQt5.QtWidgets import *
+from PyQt5.QtGui import *
+import numpy as np
+
+
+class GUIPalette(QWidget):
+    update_color = pyqtSignal(np.ndarray)
+
+    def __init__(self, grid_sz=(6, 3)):
+        QWidget.__init__(self)
+        self.color_width = 25
+        self.border = 6
+        self.win_width = grid_sz[0] * self.color_width + (grid_sz[0] + 1) * self.border
+        self.win_height = grid_sz[1] * self.color_width + (grid_sz[1] + 1) * self.border
+        self.setFixedSize(self.win_width, self.win_height)
+        self.num_colors = grid_sz[0] * grid_sz[1]
+        self.grid_sz = grid_sz
+        self.colors = None
+        self.color_id = -1
+        self.reset()
+
+    def set_colors(self, colors):
+        if colors is not None:
+            self.colors = (colors[:min(colors.shape[0], self.num_colors), :] * 255).astype(np.uint8)
+            self.color_id = -1
+            self.update()
+
+    def paintEvent(self, event):
+        painter = QPainter()
+        painter.begin(self)
+        painter.setRenderHint(QPainter.Antialiasing)
+        painter.fillRect(event.rect(), Qt.white)
+        if self.colors is not None:
+            for n, c in enumerate(self.colors):
+                ca = QColor(c[0], c[1], c[2], 255)
+                painter.setPen(QPen(Qt.black, 1))
+                painter.setBrush(ca)
+                grid_x = n % self.grid_sz[0]
+                grid_y = (n - grid_x) // self.grid_sz[0]
+                x = grid_x * (self.color_width + self.border) + self.border
+                y = grid_y * (self.color_width + self.border) + self.border
+
+                if n == self.color_id:
+                    painter.drawEllipse(x, y, self.color_width, self.color_width)
+                else:
+                    painter.drawRoundedRect(x, y, self.color_width, self.color_width, 2, 2)
+
+        painter.end()
+
+    def sizeHint(self):
+        return QSize(self.win_width, self.win_height)
+
+    def reset(self):
+        self.colors = None
+        self.mouseClicked = False
+        self.color_id = -1
+        self.update()
+
+    def selected_color(self, pos):
+        width = self.color_width + self.border
+        dx = pos.x() % width
+        dy = pos.y() % width
+        if dx >= self.border and dy >= self.border:
+            x_id = (pos.x() - dx) // width
+            y_id = (pos.y() - dy) // width
+            color_id = x_id + y_id * self.grid_sz[0]
+            return int(color_id)
+        else:
+            return -1
+
+    def update_ui(self, color_id):
+        self.color_id = int(color_id)
+        self.update()
+        if color_id >= 0:
+            print('choose color (%d) type (%s)' % (color_id, type(color_id)))
+            color = self.colors[color_id]
+            self.update_color.emit(color)
+            self.update()
+
+    def mousePressEvent(self, event):
+        if event.button() == Qt.LeftButton:  # click the point
+            color_id = self.selected_color(event.pos())
+            self.update_ui(color_id)
+            self.mouseClicked = True
+
+    def mouseMoveEvent(self, event):
+        if self.mouseClicked:
+            color_id = self.selected_color(event.pos())
+            self.update_ui(color_id)
+
+    def mouseReleaseEvent(self, event):
+        self.mouseClicked = False
diff --git a/docker/ui_PyQt5/gui_vis.py b/docker/ui_PyQt5/gui_vis.py
new file mode 100644
index 0000000..b8acaa3
--- /dev/null
+++ b/docker/ui_PyQt5/gui_vis.py
@@ -0,0 +1,98 @@
+from PyQt5.QtCore import *
+from PyQt5.QtWidgets import *
+from PyQt5.QtGui import *
+import numpy as np
+from data import lab_gamut
+import pdb
+
+try:
+    QString = unicode
+except NameError:
+    # Python 3
+    QString = str
+
+
+class GUI_VIS(QWidget):
+
+    update_color = pyqtSignal(QString)
+
+    def __init__(self, win_size=256, scale=2.0):
+        QWidget.__init__(self)
+        self.result = None
+        self.win_width = win_size
+        self.win_height = win_size
+        self.scale = scale
+        self.setFixedSize(self.win_width, self.win_height)
+
+    def paintEvent(self, event):
+        painter = QPainter()
+        painter.begin(self)
+        painter.setRenderHint(QPainter.Antialiasing)
+        painter.fillRect(event.rect(), QColor(49, 54, 49))
+        if self.result is not None:
+            h, w, c = self.result.shape
+            qImg = QImage(self.result.tostring(), w, h, QImage.Format_RGB888)
+            dw = int((self.win_width - w) // 2)
+            dh = int((self.win_height - h) // 2)
+            painter.drawImage(dw, dh, qImg)
+
+        painter.end()
+
+    def update_result(self, result):
+        self.result = result
+        self.update()
+
+    def sizeHint(self):
+        return QSize(self.win_width, self.win_height)
+
+    def reset(self):
+        self.result = None
+        self.update()
+
+    def is_valid_point(self, pos):
+        if pos is None:
+            return False
+        else:
+            x = pos.x()
+            y = pos.y()
+            return x >= 0 and y >= 0 and x < self.win_width and y < self.win_height
+
+    def scale_point(self, pnt):
+        x = int(pnt.x() / self.scale)
+        y = int(pnt.y() / self.scale)
+        return x, y
+
+    def mousePressEvent(self, event):
+        pos = event.pos()
+        x, y = self.scale_point(pos)
+        if event.button() == Qt.LeftButton and self.is_valid_point(pos):  # click the point
+            if self.result is not None:
+                color = self.result[y, x, :]
+                c = QColor(color[0], color[1], color[2])
+                self.update_color.emit(QString('background-color: %s' % c.name()))
+                print('color', color)
+
+    def mouseMoveEvent(self, event):
+        pass
+
+    def mouseReleaseEvent(self, event):
+        pass
diff --git a/docker/ui_PyQt5/ui_control.py b/docker/ui_PyQt5/ui_control.py
new file mode 100644
index 0000000..e4c6505
--- /dev/null
+++ b/docker/ui_PyQt5/ui_control.py
@@ -0,0 +1,192 @@
+import numpy as np
+from PyQt5.QtCore import *
+from PyQt5.QtGui import *
+import cv2
+
+
+class UserEdit(object):
+    def __init__(self, mode, win_size, load_size, img_size):
+        self.mode = mode
+        self.win_size = win_size
+        self.img_size = img_size
+        self.load_size = load_size
+        print('image_size', self.img_size)
+        max_width = np.max(self.img_size)
+        self.scale = float(max_width) / self.load_size
+        self.dw = int((self.win_size - img_size[0]) // 2)
+        self.dh = int((self.win_size - img_size[1]) // 2)
+        self.img_w = img_size[0]
+        self.img_h = img_size[1]
+        self.ui_count = 0
+        print(self)
+
+    def scale_point(self, in_x, in_y, w):
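+        # map a window coordinate into load_size image space, then offset by w (brush half-width)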
+        x = int((in_x - self.dw) / float(self.img_w) * self.load_size) + w
+        y = int((in_y - self.dh) / float(self.img_h) * self.load_size) + w
+        return x, y
+
+    def __str__(self):
+        return "add (%s) with win_size %3.3f, load_size %3.3f" % (self.mode, self.win_size, self.load_size)
+
+
+class PointEdit(UserEdit):
+    def __init__(self, win_size, load_size, img_size):
+        UserEdit.__init__(self, 'point', win_size, load_size, img_size)
+
+    def add(self, pnt, color, userColor, width, ui_count):
+        self.pnt = pnt
+        self.color = color
+        self.userColor = userColor
+        self.width = width
+        self.ui_count = ui_count
+
+    def select_old(self, pnt, ui_count):
+        self.pnt = pnt
+        self.ui_count = ui_count
+        return self.userColor, self.width
+
+    def update_color(self, color, userColor):
+        self.color = color
+        self.userColor = userColor
+
+    def updateInput(self, im, mask, vis_im):
+        w = int(self.width / self.scale)
+        pnt = self.pnt
+        x1, y1 = self.scale_point(pnt.x(), pnt.y(), -w)
+        tl = (x1, y1)
+        x2, y2 = self.scale_point(pnt.x(), pnt.y(), w)
+        br = (x2, y2)
+        c = (self.color.red(), self.color.green(), self.color.blue())
+        uc = (self.userColor.red(), self.userColor.green(), self.userColor.blue())
+        cv2.rectangle(mask, tl, br, 255, -1)
+        cv2.rectangle(im, tl, br, c, -1)
+        cv2.rectangle(vis_im, tl, br, uc, -1)
+
+    def is_same(self, pnt):
+        dx = abs(self.pnt.x() - pnt.x())
+        dy = abs(self.pnt.y() - pnt.y())
+        return dx <= self.width + 1 and dy <= self.width + 1
+
+    def update_painter(self, painter):
+        w = max(3, self.width)
+        c = self.color
+        r = c.red()
+        g = c.green()
+        b = c.blue()
+        ca = QColor(c.red(), c.green(), c.blue(), 255)
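+        # outline with black or white, whichever contrasts more with the fill (squared RGB distance)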
+        d_to_black = r * r + g * g + b * b
+        d_to_white = (255 - r) * (255 - r) + (255 - g) * (255 - g) + (255 - b) * (255 - b)
+        if d_to_black > d_to_white:
+            painter.setPen(QPen(Qt.black, 1))
+        else:
+            painter.setPen(QPen(Qt.white, 1))
+        painter.setBrush(ca)
+        painter.drawRoundedRect(self.pnt.x() - w, self.pnt.y() - w, 1 + 2 * w, 1 + 2 * w, 2, 2)
+
+
+class UIControl:
+    def __init__(self, win_size=256, load_size=512):
+        self.win_size = win_size
+        self.load_size = load_size
+        self.reset()
+        self.userEdit = None
+        self.userEdits = []
+        self.ui_count = 0
+
+    def setImageSize(self, img_size):
+        self.img_size = img_size
+
+    def addStroke(self, prevPnt, nextPnt, color, userColor, width):
+        pass
+
+    def erasePoint(self, pnt):
+        isErase = False
+        for edit_id, ue in enumerate(self.userEdits):
+            if ue.is_same(pnt):
+                self.userEdits.remove(ue)
+                print('remove user edit %d\n' % edit_id)
+                isErase = True
+                break
+        return isErase
+
+    def addPoint(self, pnt, color, userColor, width):
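+        # reuse the edit under the click if one exists; otherwise create a new PointEdit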
+        self.ui_count += 1
+        print('process add Point')
+        self.userEdit = None
+        isNew = True
+        for edit_id, ue in enumerate(self.userEdits):
+            if ue.is_same(pnt):
+                self.userEdit = ue
+                isNew = False
+                print('select user edit %d\n' % edit_id)
+                break
+
+        if self.userEdit is None:
+            self.userEdit = PointEdit(self.win_size, self.load_size, self.img_size)
+            self.userEdits.append(self.userEdit)
+            print('add user edit %d\n' % len(self.userEdits))
+            self.userEdit.add(pnt, color, userColor, width, self.ui_count)
+            return userColor, width, isNew
+        else:
+            userColor, width = self.userEdit.select_old(pnt, self.ui_count)
+            return userColor, width, isNew
+
+    def movePoint(self, pnt, color, userColor, width):
+        self.userEdit.add(pnt, color, userColor, width, self.ui_count)
+
+    def update_color(self, color, userColor):
+        self.userEdit.update_color(color, userColor)
+
+    def update_painter(self, painter):
+        for ue in self.userEdits:
+            if ue is not None:
+                ue.update_painter(painter)
+
+    def get_stroke_image(self, im):
+        return im
+
+    def used_colors(self):  # get recently used colors
+        if len(self.userEdits) == 0:
+            return None
+        nEdits = len(self.userEdits)
+        ui_counts = np.zeros(nEdits)
+        ui_colors = np.zeros((nEdits, 3))
+        for n, ue in enumerate(self.userEdits):
+            ui_counts[n] = ue.ui_count
+            c = ue.userColor
+            ui_colors[n, :] = [c.red(), c.green(), c.blue()]
+
+        ids = np.argsort(-ui_counts)
+        ui_colors = ui_colors[ids, :]
+        unique_colors = []
+        for ui_color in ui_colors:
+            is_seen = False
+            for u_color in unique_colors:
+                d = np.sum(np.abs(u_color - ui_color))
+                if d < 0.1:
+                    is_seen = True
+                    break
+
+            if not is_seen:
+                unique_colors.append(ui_color)
+
+        unique_colors = np.vstack(unique_colors)
+        return unique_colors / 255.0
+
+    def get_input(self):
+        h = self.load_size
+        w = self.load_size
+        im = np.zeros((h, w, 3), np.uint8)
+        mask = np.zeros((h, w, 1), np.uint8)
+        vis_im = np.zeros((h, w, 3), np.uint8)
+
+        for ue in self.userEdits:
+            ue.updateInput(im, mask, vis_im)
+
+        return im, mask
+
+    def reset(self):
+        self.userEdits = []
+        self.userEdit = None
+        self.ui_count = 0
diff --git a/docker/ui_PyQt5/utils.py b/docker/ui_PyQt5/utils.py
new file mode 100644
index 0000000..e262878
--- /dev/null
+++ b/docker/ui_PyQt5/utils.py
@@ -0,0 +1,108 @@
+from __future__ import print_function
+
+import inspect
+import re
+import numpy as np
+import cv2
+import os
+try:
+    import cPickle as pickle  # Python 2
+except ImportError:
+    import pickle  # Python 3
+
+
+def debug_trace():
+    from PyQt5.QtCore import pyqtRemoveInputHook
+    from pdb import set_trace
+    pyqtRemoveInputHook()
+    set_trace()
+
+
+def PickleLoad(file_name):
+    try:
+        with open(file_name, 'rb') as f:
+            data = pickle.load(f)
+    except UnicodeDecodeError:
+        with open(file_name, 'rb') as f:
+            data = pickle.load(f, encoding='latin1')
+    return data
+
+
+def PickleSave(file_name, data):
+    with open(file_name, "wb") as f:
+        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+def varname(p):
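+    # best-effort helper: recover the caller's argument name from its own source line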
+    for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
+        m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
+        if m:
+            return m.group(1)
+
+
+def print_numpy(x, val=True, shp=False):
+    x = x.astype(np.float64)
+    if shp:
+        print('shape,', x.shape)
+    if val:
+        x = x.flatten()
+        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+def CVShow(im, im_name='', wait=1):
+    if len(im.shape) >= 3 and im.shape[2] == 3:
+        im_show = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+    else:
+        im_show = im
+
+    cv2.imshow(im_name, im_show)
+    cv2.waitKey(wait)
+    return im_show
+
+
+def average_image(imgs, weights):
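+    # weighted mean over the batch: sum_i w_i * img_i / sum_i w_i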
+    im_weights = np.tile(weights[:, np.newaxis, np.newaxis, np.newaxis], (1, imgs.shape[1], imgs.shape[2], imgs.shape[3]))
+    imgs_f = imgs.astype(np.float32)
+    weights_norm = np.mean(im_weights)
+    average_f = np.mean(imgs_f * im_weights, axis=0) / weights_norm
+    average = average_f.astype(np.uint8)
+    return average
+
+
+def mkdirs(paths):
+    if isinstance(paths, list):
+        for path in paths:
+            mkdir(path)
+    else:
+        mkdir(paths)
+
+
+def mkdir(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def grid_vis(X, nh, nw):  # [buggy]
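+    # tile a batch of images into an nh x nw grid image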
+    if X.shape[0] == 1:
+        return X[0]
+
+    # nc = 3
+    if X.ndim == 3:
+        X = X[..., np.newaxis]
+    if X.shape[-1] == 1:
+        X = np.tile(X, [1, 1, 1, 3])
+
+    h, w = X[0].shape[:2]
+
+    if X.dtype == np.uint8:
+        img = np.ones((h * nh, w * nw, 3), np.uint8) * 255
+    else:
+        img = np.ones((h * nh, w * nw, 3), X.dtype)
+
+    for n, x in enumerate(X):
+        j = n // nw
+        i = n % nw
+        img[j * h:j * h + h, i * w:i * w + w, :] = x
+    img = np.squeeze(img)
+    return img

From 1b6679afc8f72858e05d00c203607eadee9d89a1 Mon Sep 17 00:00:00 2001
From: Vishwaesh Rajiv <vrajiv@vishwaeshs-mbp.lan>
Date: Tue, 6 Aug 2019 23:09:17 -0400
Subject: [PATCH 2/2] deleted pwd and ls lines

---
 docker/Dockerfile | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index 277488c..77014ff 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -3,14 +3,10 @@ FROM continuumio/miniconda3
 COPY ./install/docker_conda_install.sh /app/install/
 WORKDIR /app/install/
 
-RUN pwd
-RUN ls
-
 RUN conda create -n env python=3.6
 RUN echo "source activate env" > ~/.bashrc
 ENV PATH /opt/conda/envs/env/bin:$PATH
 
-RUN ls
 RUN ./docker_conda_install.sh
 
 COPY . /app
