diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f881ae1c5a3883bb5262939af511153c54313320 --- /dev/null +++ b/.gitignore @@ -0,0 +1,163 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/
+
+data/*
+outputs/*
\ No newline at end of file
diff --git a/README.md b/README.md
index 3c6a32ab64b54642c13384987463d952322ee10e..f8b065a22bb9edfe212f250b1a0dddea4288da0b 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
 # VQ-VAE-Corn
+## Introduction
-
-## Getting started
+Corn silage images are visually chaotic, and identifying the elements in a picture requires expert knowledge, which makes it hard to label and annotate images for supervised learning. VQ-VAE is an unsupervised learning method that extracts features from an image that are good enough to reconstruct it.
 
 To make it easy for you to get started with GitLab, here's a list of recommended next steps.
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000000000000000000000000000000000000..deb00fe60e88f9b1a1f3f6a4d5256ab80ed86292
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,17 @@
+services:
+  vq-vae:
+    build:
+      dockerfile: ./docker/Dockerfile
+    volumes:
+      - ./data:/etc/app/data
+      - ./src:/etc/app/src
+      - ./outputs:/etc/app/outputs
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: 1
+              capabilities: [gpu]
+    stdin_open: true
+    tty: true
\ No newline at end of file
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..f233e49845ea2cc02cb3f03d5ff873e529b97455
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,15 @@
+FROM pytorch/pytorch:latest
+
+RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
+RUN pip install opencv-python
+RUN pip install torchsummary
+RUN pip install tqdm
+RUN pip install matplotlib
+RUN pip install scipy
+RUN pip install six
+RUN pip install umap-learn
+RUN pip install scikit-image
+RUN apt install -y git
+RUN pip install pandas
+
+WORKDIR /etc/app/
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/dataset.py b/src/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..115a509f10c8be055a98bbf078a31daae7c1c105
--- /dev/null
+++ b/src/dataset.py
@@ -0,0 +1,77 @@
+import glob
+import os
+import cv2
+import torch
+import numpy as np
+from tqdm import tqdm
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms, utils
+from skimage import io, transform
+
+def get_files_path(path):
+    files_path = glob.glob(os.path.join(path,"**/*.jpeg"),recursive=True)
+    if len(files_path) == 0:
+        print("Invalid file path: ", path)
+    return files_path
+
+class CornDataset(Dataset):
+    def __init__(self, root_dir, transform=None):
+        self.files = get_files_path(root_dir)
+        self.transform = transform
+
+    def __len__(self):
+        return(len(self.files))
+
+    def __getitem__(self, idx):
+        if torch.is_tensor(idx):
+            idx = idx.tolist()
+
+        path = self.files[idx]
+        image = io.imread(path)
+        # image = cv2.cvtColor(cv2.imread(path),cv2.COLOR_BGR2RGB)
+        # the label is derived from the folder structure: <plot study>/<plot>/<image>.jpeg
+        path_list = path.split("/")
+        plot = path_list[-2].upper()
+        if path_list[-3]=='Plot Study 1 - 09152021':
+            ps = "PS1"
+        elif path_list[-3]=='Plot Study 2 - 09202021':
+            ps = "PS2"
+        else:
+            print("Error loading the image label: ", path)
+            ps = "UNKNOWN"
+        label = ps + "-" + plot
+
+        sample = {'image': image, 'label': label}
+        if self.transform:
+            sample["image"] = self.transform(sample["image"])
+
+        return sample
+
+class CornDataset2(Dataset):
+    def __init__(self, root_dir, transform=None):
+        self.files = get_files_path(root_dir)
+        self.transform =
transform + + def __len__(self): + return(len(self.files)) + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.to_list() + + path = self.files[idx] + image = io.imread(path) + # image = cv2.cvtColor(cv2.imread(path),cv2.COLOR_BGR2RGB) + path_list = path.split("/") + plot = path_list[-2].upper() + if path_list[-3]=='Plot Study 1 - 09152021': + ps = "PS1" + elif path_list[-3]=='Plot Study 2 - 09202021': + ps = "PS2" + else: + print("Error loading the image label") + label = ps + "-" + plot + + sample = {'image': image, 'label': label} + if self.transform: + sample["image"] = self.transform(sample["image"]) + + return sample["image"],sample["label"] \ No newline at end of file diff --git a/src/get_encoded.py b/src/get_encoded.py new file mode 100644 index 0000000000000000000000000000000000000000..b928da1b7e40d62be28114e4d2ff33731cec3143 --- /dev/null +++ b/src/get_encoded.py @@ -0,0 +1,115 @@ +import matplotlib.pyplot as plt +import numpy as np +from scipy.signal import savgol_filter + +from six.moves import xrange + +import umap + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +import torch.optim as optim + +import torchvision.datasets as datasets +import torchvision.transforms as transforms +from torchvision.utils import make_grid + +from sklearn.model_selection import train_test_split + +from dataset import CornDataset + +import time + +from model import Model + +def show(img,path): + npimg = img.numpy() + fig = plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest') + fig.axes.get_xaxis().set_visible(False) + fig.axes.get_yaxis().set_visible(False) + plt.savefig(path) + plt.clf() + +if __name__=="__main__": + print("Loading device...") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(device, " device loaded") + + print("Creating dataset...") + data = CornDataset("./data",transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # Imagenet values + ])) + print("Dataset created!") + + # training_data, validation_data = train_test_split(data, test_size=0.33, random_state=42) + # print(training_data) + data_variance = 0.225**2 + + batch_size = 32 + num_training_updates = 10000 + + num_hiddens = 128 + num_residual_hiddens = 32 + num_residual_layers = 2 + + embedding_dim = 64 + num_embeddings = 512 + + commitment_cost = 0.25 + + decay = 0.99 + + learning_rate = 1e-3 + + print("Making loader") + loader = DataLoader(data, + batch_size=1, + shuffle=True, + pin_memory=True) + + print("Making model") + model = Model(num_hiddens, num_residual_layers, num_residual_hiddens, + num_embeddings, embedding_dim, + commitment_cost, decay).to(device) + + print("Loading model") + model.load_state_dict(torch.load("outputs/model.pt")) + model.eval() + + enc_list = [] + for d in loader: + label = d["label"] + image = d["image"] + image = image.to(device) + encoded = model._encoder(image) + output = model._pre_vq_conv(encoded) + vq = model._vq_vae(output)[1] + print(vq) + break + # valid_originals = next(iter(validation_loader))["image"] + # valid_originals = valid_originals.to(device) + + # vq_output_eval = model._pre_vq_conv(model._encoder(valid_originals)) + # _, valid_quantize, _, _ = model._vq_vae(vq_output_eval) + # valid_reconstructions = model._decoder(valid_quantize) + + # invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. 
], + # std = [ 1/0.229, 1/0.224, 1/0.225 ]), + # transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ], + # std = [ 1., 1., 1. ]), + # ]) + + # rec_tensor = invTrans(valid_reconstructions) + # orig_tensor = invTrans(valid_originals) + + # show(make_grid(rec_tensor.cpu().data), "outputs/visualization_rec.jpeg") + # show(make_grid(orig_tensor.cpu().data), "outputs/visualization_org.jpeg") + + # proj = umap.UMAP(n_neighbors=3, + # min_dist=0.1, + # metric='cosine').fit_transform(model._vq_vae._embedding.weight.data.cpu()) + # plt.scatter(proj[:,0], proj[:,1], alpha=0.3) + # plt.savefig("outputs/embeddings.jpeg") \ No newline at end of file diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000000000000000000000000000000000000..3f9e3efa7ed564917dfc613dd134cceaac85ce88 --- /dev/null +++ b/src/model.py @@ -0,0 +1,332 @@ +import matplotlib.pyplot as plt +import numpy as np +from scipy.signal import savgol_filter + +from six.moves import xrange + +import umap + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +import torch.optim as optim + +import torchvision.datasets as datasets +import torchvision.transforms as transforms +from torchvision.utils import make_grid + +from sklearn.model_selection import train_test_split + +from dataset import CornDataset + +import time + +class VectorQuantizer(nn.Module): + def __init__(self, num_embeddings, embedding_dim, commitment_cost): + super(VectorQuantizer, self).__init__() + + self._embedding_dim = embedding_dim + self._num_embeddings = num_embeddings + + self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) + self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings) + self._commitment_cost = commitment_cost + + def forward(self, inputs): + # convert inputs from BCHW -> BHWC + inputs = inputs.permute(0, 2, 3, 1).contiguous() + input_shape = inputs.shape + + # Flatten input + flat_input = inputs.view(-1, self._embedding_dim) + + # Calculate distances + distances = (torch.sum(flat_input**2, dim=1, keepdim=True) + + torch.sum(self._embedding.weight**2, dim=1) + - 2 * torch.matmul(flat_input, self._embedding.weight.t())) + + # Encoding + encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) + encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) + encodings.scatter_(1, encoding_indices, 1) + + # Quantize and unflatten + quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) + + # Loss + e_latent_loss = F.mse_loss(quantized.detach(), inputs) + q_latent_loss = F.mse_loss(quantized, inputs.detach()) + loss = q_latent_loss + self._commitment_cost * e_latent_loss + + quantized = inputs + (quantized - inputs).detach() + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + # convert quantized from BHWC -> BCHW + return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings + +class VectorQuantizerEMA(nn.Module): + def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5): + super(VectorQuantizerEMA, self).__init__() + + self._embedding_dim = embedding_dim + self._num_embeddings = num_embeddings + + self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) + self._embedding.weight.data.normal_() + self._commitment_cost = commitment_cost + + self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings)) + self._ema_w = 
nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim)) + self._ema_w.data.normal_() + + self._decay = decay + self._epsilon = epsilon + + def forward(self, inputs): + # convert inputs from BCHW -> BHWC + inputs = inputs.permute(0, 2, 3, 1).contiguous() + input_shape = inputs.shape + + # Flatten input + flat_input = inputs.view(-1, self._embedding_dim) + + # Calculate distances + distances = (torch.sum(flat_input**2, dim=1, keepdim=True) + + torch.sum(self._embedding.weight**2, dim=1) + - 2 * torch.matmul(flat_input, self._embedding.weight.t())) + + # Encoding + encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) + encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) + encodings.scatter_(1, encoding_indices, 1) + + # Quantize and unflatten + quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) + + # Use EMA to update the embedding vectors + if self.training: + self._ema_cluster_size = self._ema_cluster_size * self._decay + \ + (1 - self._decay) * torch.sum(encodings, 0) + + # Laplace smoothing of the cluster size + n = torch.sum(self._ema_cluster_size.data) + self._ema_cluster_size = ( + (self._ema_cluster_size + self._epsilon) + / (n + self._num_embeddings * self._epsilon) * n) + + dw = torch.matmul(encodings.t(), flat_input) + self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw) + + self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1)) + + # Loss + e_latent_loss = F.mse_loss(quantized.detach(), inputs) + loss = self._commitment_cost * e_latent_loss + + # Straight Through Estimator + quantized = inputs + (quantized - inputs).detach() + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + # convert quantized from BHWC -> BCHW + return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings + +class Residual(nn.Module): + def __init__(self, in_channels, num_hiddens, num_residual_hiddens): + super(Residual, self).__init__() + self._block = nn.Sequential( + nn.ReLU(True), + nn.Conv2d(in_channels=in_channels, + out_channels=num_residual_hiddens, + kernel_size=3, stride=1, padding=1, bias=False), + nn.ReLU(True), + nn.Conv2d(in_channels=num_residual_hiddens, + out_channels=num_hiddens, + kernel_size=1, stride=1, bias=False) + ) + + def forward(self, x): + return x + self._block(x) + + +class ResidualStack(nn.Module): + def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens): + super(ResidualStack, self).__init__() + self._num_residual_layers = num_residual_layers + self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens) + for _ in range(self._num_residual_layers)]) + + def forward(self, x): + for i in range(self._num_residual_layers): + x = self._layers[i](x) + return F.relu(x) + +class Encoder(nn.Module): + def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens): + super(Encoder, self).__init__() + + self._conv_1 = nn.Conv2d(in_channels=in_channels, + out_channels=num_hiddens//2, + kernel_size=4, + stride=2, padding=1) + self._conv_2 = nn.Conv2d(in_channels=num_hiddens//2, + out_channels=num_hiddens, + kernel_size=4, + stride=2, padding=1) + self._conv_3 = nn.Conv2d(in_channels=num_hiddens, + out_channels=num_hiddens, + kernel_size=3, + stride=1, padding=1) + self._residual_stack = ResidualStack(in_channels=num_hiddens, + num_hiddens=num_hiddens, + 
num_residual_layers=num_residual_layers, + num_residual_hiddens=num_residual_hiddens) + + def forward(self, inputs): + x = self._conv_1(inputs) + x = F.relu(x) + + x = self._conv_2(x) + x = F.relu(x) + + x = self._conv_3(x) + return self._residual_stack(x) + +class Decoder(nn.Module): + def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens): + super(Decoder, self).__init__() + + self._conv_1 = nn.Conv2d(in_channels=in_channels, + out_channels=num_hiddens, + kernel_size=3, + stride=1, padding=1) + + self._residual_stack = ResidualStack(in_channels=num_hiddens, + num_hiddens=num_hiddens, + num_residual_layers=num_residual_layers, + num_residual_hiddens=num_residual_hiddens) + + self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens, + out_channels=num_hiddens//2, + kernel_size=4, + stride=2, padding=1) + + self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens//2, + out_channels=3, + kernel_size=4, + stride=2, padding=1) + + def forward(self, inputs): + x = self._conv_1(inputs) + + x = self._residual_stack(x) + + x = self._conv_trans_1(x) + x = F.relu(x) + + return self._conv_trans_2(x) + +class Model(nn.Module): + def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens, + num_embeddings, embedding_dim, commitment_cost, decay=0): + super(Model, self).__init__() + + self._encoder = Encoder(3, num_hiddens, + num_residual_layers, + num_residual_hiddens) + self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens, + out_channels=embedding_dim, + kernel_size=1, + stride=1) + if decay > 0.0: + self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim, + commitment_cost, decay) + else: + self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim, + commitment_cost) + self._decoder = Decoder(embedding_dim, + num_hiddens, + num_residual_layers, + num_residual_hiddens) + + def forward(self, x): + z = self._encoder(x) + z = self._pre_vq_conv(z) + loss, quantized, perplexity, _ = self._vq_vae(z) + x_recon = self._decoder(quantized) + + return loss, x_recon, perplexity + +class VariationalEncoder(nn.Module): + def __init__(self, latent_dims): + super(VariationalEncoder, self).__init__() + self.conv1 = nn.Conv2d(3, 8, 3, stride=2, padding=1) + self.conv2 = nn.Conv2d(8, 16, 3, stride=2, padding=1) + self.batch2 = nn.BatchNorm2d(16) + self.conv3 = nn.Conv2d(16, 32, 3, stride=2, padding=0) + self.linear1 = nn.Linear(3*3*32, 128) + self.linear2 = nn.Linear(128, latent_dims) + self.linear3 = nn.Linear(128, latent_dims) + + self.N = torch.distributions.Normal(0, 1) + self.N.loc = self.N.loc.cuda() # hack to get sampling on the GPU + self.N.scale = self.N.scale.cuda() + self.kl = 0 + + def forward(self, x): + # x = x.to(device) + x = F.relu(self.conv1(x)) + x = F.relu(self.batch2(self.conv2(x))) + x = F.relu(self.conv3(x)) + x = torch.flatten(x, start_dim=1) + x = F.relu(self.linear1(x)) + mu = self.linear2(x) + sigma = torch.exp(self.linear3(x)) + z = mu + sigma*self.N.sample(mu.shape) + self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum() + return z + +class VariationalDecoder(nn.Module): + + def __init__(self, latent_dims): + super().__init__() + + self.decoder_lin = nn.Sequential( + nn.Linear(latent_dims, 128), + nn.ReLU(True), + nn.Linear(128, 3 * 3 * 32), + nn.ReLU(True) + ) + + self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3)) + + self.decoder_conv = nn.Sequential( + nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0), + nn.BatchNorm2d(16), + nn.ReLU(True), + nn.ConvTranspose2d(16, 8, 3, stride=2, 
padding=1, output_padding=1),
+            nn.BatchNorm2d(8),
+            nn.ReLU(True),
+            nn.ConvTranspose2d(8, 3, 3, stride=2, padding=1, output_padding=1)
+        )
+
+    def forward(self, x):
+        x = self.decoder_lin(x)
+        x = self.unflatten(x)
+        x = self.decoder_conv(x)
+        x = torch.sigmoid(x)
+        return x
+
+class VariationalAutoencoder(nn.Module):
+    def __init__(self, latent_dims,device):
+        super(VariationalAutoencoder, self).__init__()
+        self.encoder = VariationalEncoder(latent_dims)
+        self.decoder = VariationalDecoder(latent_dims)
+        self.device = device
+
+    def forward(self, x):
+        x = x.to(self.device)
+        z = self.encoder(x)
+        return self.decoder(z)
\ No newline at end of file
diff --git a/src/test.py b/src/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8152f0a9e52a64dc5fcd30e0256c89546efc8ae
--- /dev/null
+++ b/src/test.py
@@ -0,0 +1,110 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.signal import savgol_filter
+
+from six.moves import xrange
+
+import umap
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import torch.optim as optim
+
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+from torchvision.utils import make_grid
+
+from sklearn.model_selection import train_test_split
+
+from dataset import CornDataset
+
+import time
+
+from model import Model
+
+def show(img,path):
+    npimg = img.numpy()
+    fig = plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')
+    fig.axes.get_xaxis().set_visible(False)
+    fig.axes.get_yaxis().set_visible(False)
+    plt.savefig(path)
+    plt.clf()
+
+if __name__=="__main__":
+    print("Loading device...")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(device, " device loaded")
+
+    print("Creating dataset...")
+    # CornDataset returns {'image', 'label'} dicts, which is what the loader access below expects
+    data = CornDataset("./data",transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # Imagenet values
+        ]))
+    print("Dataset created!")
+
+    training_data, validation_data = train_test_split(data, test_size=0.33, random_state=42)
+    # print(training_data)
+    data_variance = 0.225**2
+
+    batch_size = 32
+    num_training_updates = 10000
+
+    num_hiddens = 128
+    num_residual_hiddens = 32
+    num_residual_layers = 2
+
+    embedding_dim = 64
+    num_embeddings = 512
+
+    commitment_cost = 0.25
+
+    decay = 0.99
+
+    learning_rate = 1e-3
+
+    print("Making training loader")
+    # training_loader = DataLoader(training_data,
+    #                              batch_size=batch_size,
+    #                              shuffle=True,
+    #                              pin_memory=True)
+
+    validation_loader = DataLoader(validation_data,
+                                   batch_size=1,
+                                   shuffle=True,
+                                   pin_memory=True)
+    print("Making model")
+    model = Model(num_hiddens, num_residual_layers, num_residual_hiddens,
+                  num_embeddings, embedding_dim,
+                  commitment_cost, decay).to(device)
+
+    print("Loading model")
+    model.load_state_dict(torch.load("outputs/model.pt"))
+    model.eval()
+
+
+    valid_originals = next(iter(validation_loader))["image"]
+    valid_originals = valid_originals.to(device)
+
+    vq_output_eval = model._pre_vq_conv(model._encoder(valid_originals))
+    _, valid_quantize, _, _ = model._vq_vae(vq_output_eval)
+    valid_reconstructions = model._decoder(valid_quantize)
+
+    invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],
+                                                          std = [ 1/0.229, 1/0.224, 1/0.225 ]),
+                                    transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
+                                                          std = [ 1., 1., 1.
]), + ]) + + rec_tensor = invTrans(valid_reconstructions) + orig_tensor = invTrans(valid_originals) + + show(make_grid(rec_tensor.cpu().data), "outputs/visualization_rec.jpeg") + show(make_grid(orig_tensor.cpu().data), "outputs/visualization_org.jpeg") + + proj = umap.UMAP(n_neighbors=3, + min_dist=0.1, + metric='cosine').fit_transform(model._vq_vae._embedding.weight.data.cpu()) + plt.scatter(proj[:,0], proj[:,1], alpha=0.3) + plt.savefig("outputs/embeddings.jpeg") \ No newline at end of file diff --git a/src/train-vae.py b/src/train-vae.py new file mode 100644 index 0000000000000000000000000000000000000000..4d564062bbd289f677f6cce1bd5aaaeee11f4200 --- /dev/null +++ b/src/train-vae.py @@ -0,0 +1,139 @@ +import matplotlib.pyplot as plt # plotting library +import numpy as np # this module is useful to work with numerical arrays +import pandas as pd +import random +import torch +import torchvision +from torchvision import transforms +from torch.utils.data import DataLoader,random_split +from torch import nn +import torch.nn.functional as F +import torch.optim as optim +from sklearn.model_selection import train_test_split + +from dataset import CornDataset2 +from model import VariationalAutoencoder + +### Training function +def train_epoch(vae, device, dataloader, optimizer): + # Set train mode for both the encoder and the decoder + vae.train() + train_loss = 0.0 + # Iterate the dataloader (we do not need the label values, this is unsupervised learning) + for data, _ in dataloader: + # print(data) + x = data + # Move tensor to the proper device + x = x.to(device) + x_hat = vae(x) + # Evaluate loss + loss = ((x - x_hat)**2).sum() + vae.encoder.kl + + # Backward pass + optimizer.zero_grad() + loss.backward() + optimizer.step() + # Print batch loss + print('\t partial train loss (single batch): %f' % (loss.item())) + train_loss+=loss.item() + + return train_loss / len(dataloader.dataset) + + +### Testing function +def test_epoch(vae, device, dataloader): + # Set evaluation mode for encoder and decoder + vae.eval() + val_loss = 0.0 + with torch.no_grad(): # No need to track the gradients + for data, _ in dataloader: + # print(data) + x = data + # Move tensor to the proper device + x = x.to(device) + # Encode data + encoded_data = vae.encoder(x) + # Decode data + x_hat = vae(x) + loss = ((x - x_hat)**2).sum() + vae.encoder.kl + val_loss += loss.item() + + return val_loss / len(dataloader.dataset) + + +invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ], + std = [ 1/0.229, 1/0.224, 1/0.225 ]), + transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ], + std = [ 1., 1., 1. 
]),
+                               ])
+
+def plot_ae_outputs(encoder,decoder,dataset,device,iter,n=10):
+    plt.figure(figsize=(16,4.5))
+    encoder.eval()
+    decoder.eval()
+    # show the first n test images (the corn dataset has no integer class targets to index by)
+    for i in range(n):
+        ax = plt.subplot(2,n,i+1)
+        # dataset items are (image, label) tuples; keep the normalized image for the encoder
+        img = dataset[i][0].unsqueeze(0).to(device)
+        with torch.no_grad():
+            rec_img = decoder(encoder(img))
+        # undo the ImageNet normalization and move channels last for imshow
+        plt.imshow(invTrans(img.cpu()).squeeze(0).permute(1,2,0).numpy())
+        ax.get_xaxis().set_visible(False)
+        ax.get_yaxis().set_visible(False)
+        if i == n//2:
+            ax.set_title('Original images')
+        ax = plt.subplot(2, n, i + 1 + n)
+        plt.imshow(invTrans(rec_img.cpu()).squeeze(0).permute(1,2,0).numpy())
+        ax.get_xaxis().set_visible(False)
+        ax.get_yaxis().set_visible(False)
+        if i == n//2:
+            ax.set_title('Reconstructed images')
+    plt.savefig(f"outputs/vae-sample-{iter}.jpeg")
+
+def main():
+
+    print("Loading dataset...")
+    data = CornDataset2("./data",transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # Imagenet values
+        ]))
+
+    print("Splitting dataset...")
+    train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
+    train_data, val_data = train_test_split(train_data, test_size=0.25, random_state=42)
+    # val_data = test_data
+
+    batch_size = 32
+
+    print("Creating loaders...")
+    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
+    valid_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
+    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,shuffle=True)
+
+    torch.manual_seed(0)
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    print(f'Selected device: {device}')
+
+    d = 256
+
+    vae = VariationalAutoencoder(latent_dims=d, device=device)
+
+    lr = 1e-3
+    optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=1e-5)
+
+
+    vae.to(device)
+
+    num_epochs = 50
+    for epoch in range(num_epochs):
+        train_loss = train_epoch(vae,device,train_loader,optim)
+        val_loss = test_epoch(vae,device,valid_loader)
+        print('\n EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'.format(epoch + 1, num_epochs,train_loss,val_loss))
+        plot_ae_outputs(vae.encoder,vae.decoder,test_data,device,epoch,n=10)
+
+    return 0
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/src/train-vq-vae.py b/src/train-vq-vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..56527fbae4339f74b806a36553fdf0f0290d12b0
--- /dev/null
+++ b/src/train-vq-vae.py
@@ -0,0 +1,356 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.signal import savgol_filter
+
+from six.moves import xrange
+
+import umap
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import torch.optim as optim
+
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+from torchvision.utils import make_grid
+
+from sklearn.model_selection import train_test_split
+
+from dataset import CornDataset
+
+import time
+
+class VectorQuantizer(nn.Module):
+    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
+        super(VectorQuantizer, self).__init__()
+
+        self._embedding_dim = embedding_dim
+        self._num_embeddings = num_embeddings
+
+        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
+        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
+        self._commitment_cost = commitment_cost
+
+    def forward(self, inputs):
+        # convert inputs from
BCHW -> BHWC + inputs = inputs.permute(0, 2, 3, 1).contiguous() + input_shape = inputs.shape + + # Flatten input + flat_input = inputs.view(-1, self._embedding_dim) + + # Calculate distances + distances = (torch.sum(flat_input**2, dim=1, keepdim=True) + + torch.sum(self._embedding.weight**2, dim=1) + - 2 * torch.matmul(flat_input, self._embedding.weight.t())) + + # Encoding + encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) + encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) + encodings.scatter_(1, encoding_indices, 1) + + # Quantize and unflatten + quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) + + # Loss + e_latent_loss = F.mse_loss(quantized.detach(), inputs) + q_latent_loss = F.mse_loss(quantized, inputs.detach()) + loss = q_latent_loss + self._commitment_cost * e_latent_loss + + quantized = inputs + (quantized - inputs).detach() + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + # convert quantized from BHWC -> BCHW + return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings + +class VectorQuantizerEMA(nn.Module): + def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5): + super(VectorQuantizerEMA, self).__init__() + + self._embedding_dim = embedding_dim + self._num_embeddings = num_embeddings + + self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) + self._embedding.weight.data.normal_() + self._commitment_cost = commitment_cost + + self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings)) + self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim)) + self._ema_w.data.normal_() + + self._decay = decay + self._epsilon = epsilon + + def forward(self, inputs): + # convert inputs from BCHW -> BHWC + inputs = inputs.permute(0, 2, 3, 1).contiguous() + input_shape = inputs.shape + + # Flatten input + flat_input = inputs.view(-1, self._embedding_dim) + + # Calculate distances + distances = (torch.sum(flat_input**2, dim=1, keepdim=True) + + torch.sum(self._embedding.weight**2, dim=1) + - 2 * torch.matmul(flat_input, self._embedding.weight.t())) + + # Encoding + encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) + encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) + encodings.scatter_(1, encoding_indices, 1) + + # Quantize and unflatten + quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) + + # Use EMA to update the embedding vectors + if self.training: + self._ema_cluster_size = self._ema_cluster_size * self._decay + \ + (1 - self._decay) * torch.sum(encodings, 0) + + # Laplace smoothing of the cluster size + n = torch.sum(self._ema_cluster_size.data) + self._ema_cluster_size = ( + (self._ema_cluster_size + self._epsilon) + / (n + self._num_embeddings * self._epsilon) * n) + + dw = torch.matmul(encodings.t(), flat_input) + self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw) + + self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1)) + + # Loss + e_latent_loss = F.mse_loss(quantized.detach(), inputs) + loss = self._commitment_cost * e_latent_loss + + # Straight Through Estimator + quantized = inputs + (quantized - inputs).detach() + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + # convert quantized from BHWC -> 
BCHW + return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings + +class Residual(nn.Module): + def __init__(self, in_channels, num_hiddens, num_residual_hiddens): + super(Residual, self).__init__() + self._block = nn.Sequential( + nn.ReLU(True), + nn.Conv2d(in_channels=in_channels, + out_channels=num_residual_hiddens, + kernel_size=3, stride=1, padding=1, bias=False), + nn.ReLU(True), + nn.Conv2d(in_channels=num_residual_hiddens, + out_channels=num_hiddens, + kernel_size=1, stride=1, bias=False) + ) + + def forward(self, x): + return x + self._block(x) + + +class ResidualStack(nn.Module): + def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens): + super(ResidualStack, self).__init__() + self._num_residual_layers = num_residual_layers + self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens) + for _ in range(self._num_residual_layers)]) + + def forward(self, x): + for i in range(self._num_residual_layers): + x = self._layers[i](x) + return F.relu(x) + +class Encoder(nn.Module): + def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens): + super(Encoder, self).__init__() + + self._conv_1 = nn.Conv2d(in_channels=in_channels, + out_channels=num_hiddens//2, + kernel_size=4, + stride=2, padding=1) + self._conv_2 = nn.Conv2d(in_channels=num_hiddens//2, + out_channels=num_hiddens, + kernel_size=4, + stride=2, padding=1) + self._conv_3 = nn.Conv2d(in_channels=num_hiddens, + out_channels=num_hiddens, + kernel_size=3, + stride=1, padding=1) + self._residual_stack = ResidualStack(in_channels=num_hiddens, + num_hiddens=num_hiddens, + num_residual_layers=num_residual_layers, + num_residual_hiddens=num_residual_hiddens) + + def forward(self, inputs): + x = self._conv_1(inputs) + x = F.relu(x) + + x = self._conv_2(x) + x = F.relu(x) + + x = self._conv_3(x) + return self._residual_stack(x) + +class Decoder(nn.Module): + def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens): + super(Decoder, self).__init__() + + self._conv_1 = nn.Conv2d(in_channels=in_channels, + out_channels=num_hiddens, + kernel_size=3, + stride=1, padding=1) + + self._residual_stack = ResidualStack(in_channels=num_hiddens, + num_hiddens=num_hiddens, + num_residual_layers=num_residual_layers, + num_residual_hiddens=num_residual_hiddens) + + self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens, + out_channels=num_hiddens//2, + kernel_size=4, + stride=2, padding=1) + + self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens//2, + out_channels=3, + kernel_size=4, + stride=2, padding=1) + + def forward(self, inputs): + x = self._conv_1(inputs) + + x = self._residual_stack(x) + + x = self._conv_trans_1(x) + x = F.relu(x) + + return self._conv_trans_2(x) + +class Model(nn.Module): + def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens, + num_embeddings, embedding_dim, commitment_cost, decay=0): + super(Model, self).__init__() + + self._encoder = Encoder(3, num_hiddens, + num_residual_layers, + num_residual_hiddens) + self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens, + out_channels=embedding_dim, + kernel_size=1, + stride=1) + if decay > 0.0: + self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim, + commitment_cost, decay) + else: + self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim, + commitment_cost) + self._decoder = Decoder(embedding_dim, + num_hiddens, + num_residual_layers, + num_residual_hiddens) + + def forward(self, x): + 
z = self._encoder(x) + z = self._pre_vq_conv(z) + loss, quantized, perplexity, _ = self._vq_vae(z) + x_recon = self._decoder(quantized) + + return loss, x_recon, perplexity + +if __name__ == "__main__": + + print("Loading device...") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(device, " device loaded") + + print("Creating dataset...") + data = CornDataset("./data",transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # Imagenet values + ])) + print("Dataset created!") + + training_data, validation_data = train_test_split(data, test_size=0.33, random_state=42) + # print(training_data) + data_variance = 0.225**2 + + batch_size = 32 + num_training_updates = 10000 + + num_hiddens = 128 + num_residual_hiddens = 32 + num_residual_layers = 2 + + embedding_dim = 64 + num_embeddings = 512 + + commitment_cost = 0.25 + + decay = 0.99 + + learning_rate = 1e-3 + + print("Making training loader") + training_loader = DataLoader(training_data, + batch_size=batch_size, + shuffle=True, + pin_memory=True) + print("Making model") + model = Model(num_hiddens, num_residual_layers, num_residual_hiddens, + num_embeddings, embedding_dim, + commitment_cost, decay).to(device) + + print("Making optimizer") + optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=False) + + print("Start training") + model.train() + train_res_recon_error = [] + train_res_perplexity = [] + + st = time.time() + for i in xrange(num_training_updates): + # (data, _) = next(iter(training_loader)) + data = next(iter(training_loader))["image"] + data = data.to(device) + optimizer.zero_grad() + + vq_loss, data_recon, perplexity = model(data) + recon_error = F.mse_loss(data_recon, data) / data_variance + loss = recon_error + vq_loss + loss.backward() + + optimizer.step() + + train_res_recon_error.append(recon_error.item()) + train_res_perplexity.append(perplexity.item()) + + if (i+1) % 100 == 0: + et = time.time() + print('100 iterations time: ',et-st) + st=et + print('%d iterations' % (i+1)) + print('recon_error: %.3f' % np.mean(train_res_recon_error[-100:])) + print('perplexity: %.3f' % np.mean(train_res_perplexity[-100:])) + print() + + torch.save(model.state_dict(),"./outputs/model.pt") + + train_res_recon_error_smooth = savgol_filter(train_res_recon_error, 201, 7) + train_res_perplexity_smooth = savgol_filter(train_res_perplexity, 201, 7) + + f = plt.figure(figsize=(16,8)) + ax = f.add_subplot(1,2,1) + ax.plot(train_res_recon_error_smooth) + ax.set_yscale('log') + ax.set_title('Smoothed NMSE.') + ax.set_xlabel('iteration') + + ax = f.add_subplot(1,2,2) + ax.plot(train_res_perplexity_smooth) + ax.set_title('Smoothed Average codebook usage (perplexity).') + ax.set_xlabel('iteration') + + plt.savefig("./outputs/training.jpeg") \ No newline at end of file
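
For reference, a minimal usage sketch (not part of the patch) of how the checkpoint written by train-vq-vae.py could be loaded to pull quantized codes for a single image, essentially what src/get_encoded.py starts to do. It assumes the same layout and hyperparameters as the scripts above (dataset.py and model.py importable, images under ./data, checkpoint at outputs/model.pt).

    # Hypothetical sketch; mirrors the repo's own scripts rather than defining a new API.
    import torch
    from torchvision import transforms
    from dataset import CornDataset
    from model import Model

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = CornDataset("./data", transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  # ImageNet stats
    ]))
    # Same hyperparameters as train-vq-vae.py
    model = Model(num_hiddens=128, num_residual_layers=2, num_residual_hiddens=32,
                  num_embeddings=512, embedding_dim=64,
                  commitment_cost=0.25, decay=0.99).to(device)
    model.load_state_dict(torch.load("outputs/model.pt", map_location=device))
    model.eval()

    sample = data[0]
    with torch.no_grad():
        z = model._pre_vq_conv(model._encoder(sample["image"].unsqueeze(0).to(device)))
        _, quantized, _, encodings = model._vq_vae(z)  # encodings: one-hot codebook assignments
    print(sample["label"], quantized.shape, encodings.shape)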