diff --git a/examples/onnx_export.ipynb b/examples/onnx_export.ipynb new file mode 100644 index 0000000..6fac752 --- /dev/null +++ b/examples/onnx_export.ipynb @@ -0,0 +1,315 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "wUYE8Wz2ojXF" + }, + "source": [ + "# ONNX export example for Neural Spline Flow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "niwtdQWHojXF" + }, + "outputs": [], + "source": [ + "# Import required packages\n", + "import torch\n", + "import numpy as np\n", + "import normflows as nf\n", + "\n", + "from sklearn.datasets import make_moons\n", + "\n", + "from matplotlib import pyplot as plt\n", + "\n", + "from tqdm import tqdm\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ocyo6e9SojXG", + "scrolled": false + }, + "outputs": [], + "source": [ + "# Set up model\n", + "\n", + "# Define flows\n", + "K = 16\n", + "torch.manual_seed(0)\n", + "\n", + "latent_size = 2\n", + "hidden_units = 128\n", + "hidden_layers = 2\n", + "\n", + "flows = []\n", + "for i in range(K):\n", + " flows += [nf.flows.AutoregressiveRationalQuadraticSpline(latent_size, hidden_layers, hidden_units)]\n", + " flows += [nf.flows.LULinearPermute(latent_size)]\n", + "\n", + "# Set base distribuiton\n", + "q0 = nf.distributions.DiagGaussian(2, trainable=False)\n", + "\n", + "# Construct flow model\n", + "nfm = nf.NormalizingFlow(q0=q0, flows=flows)\n", + "\n", + "# Move model on GPU if available\n", + "enable_cuda = True\n", + "device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')\n", + "nfm = nfm.to(device)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "gfaPXsACojXG", + "outputId": "44658aca-aff5-41c5-a8e2-651b8abaf84a", + "scrolled": false + }, + "outputs": [], + "source": [ + "# Plot target distribution\n", + "x_np, _ = make_moons(2 ** 20, noise=0.1)\n", + "plt.figure(figsize=(15, 15))\n", + "plt.hist2d(x_np[:, 0], x_np[:, 1], bins=200)\n", + "plt.show()\n", + "\n", + "# Plot initial flow distribution\n", + "grid_size = 100\n", + "xx, yy = torch.meshgrid(torch.linspace(-1.5, 2.5, grid_size), torch.linspace(-2, 2, grid_size))\n", + "zz = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).view(-1, 2)\n", + "zz = zz.to(device)\n", + "\n", + "nfm.eval()\n", + "log_prob = nfm.log_prob(zz).to('cpu').view(*xx.shape)\n", + "nfm.train()\n", + "prob = torch.exp(log_prob)\n", + "prob[torch.isnan(prob)] = 0\n", + "\n", + "plt.figure(figsize=(15, 15))\n", + "plt.pcolormesh(xx, yy, prob.data.numpy())\n", + "plt.gca().set_aspect('equal', 'box')\n", + "plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "2ratRmYAojXG", + "outputId": "8359d715-53f2-4cef-fd3e-5fa6d74c80ed", + "scrolled": false + }, + "outputs": [], + "source": [ + "# Train model\n", + "max_iter = 10\n", + "num_samples = 2 ** 9\n", + "show_iter = 5\n", + "\n", + "\n", + "loss_hist = np.array([])\n", + "\n", + "optimizer = torch.optim.Adam(nfm.parameters(), lr=1e-3, weight_decay=1e-5)\n", + "for it in tqdm(range(max_iter)):\n", + " optimizer.zero_grad()\n", + "\n", + " # Get training samples\n", + " x_np, _ = make_moons(num_samples, noise=0.1)\n", + " x = torch.tensor(x_np).float().to(device)\n", + "\n", + " # Compute loss\n", + " loss = nfm.forward_kld(x)\n", + "\n", + " # Do backprop and optimizer step\n", + " if ~(torch.isnan(loss) | torch.isinf(loss)):\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # Log loss\n", + " loss_hist = np.append(loss_hist, loss.to('cpu').data.numpy())\n", + "\n", + " # Plot learned distribution\n", + " if (it + 1) % show_iter == 0:\n", + " nfm.eval()\n", + " log_prob = nfm.log_prob(zz)\n", + " nfm.train()\n", + " prob = torch.exp(log_prob.to('cpu').view(*xx.shape))\n", + " prob[torch.isnan(prob)] = 0\n", + "\n", + " plt.figure(figsize=(15, 15))\n", + " plt.pcolormesh(xx, yy, prob.data.numpy())\n", + " plt.gca().set_aspect('equal', 'box')\n", + " plt.show()\n", + "\n", + "# Plot loss\n", + "plt.figure(figsize=(10, 10))\n", + "plt.plot(loss_hist, label='loss')\n", + "plt.legend()\n", + "plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 754 + }, + "id": "KTeqlOJ-ojXG", + "outputId": "7644d64e-7f84-4d34-81be-d5d3732182a8" + }, + "outputs": [], + "source": [ + "# Plot learned distribution\n", + "nfm.eval()\n", + "log_prob = nfm.log_prob(zz).to('cpu').view(*xx.shape)\n", + "nfm.train()\n", + "prob = torch.exp(log_prob)\n", + "prob[torch.isnan(prob)] = 0\n", + "\n", + "plt.figure(figsize=(15, 15))\n", + "plt.pcolormesh(xx, yy, prob.data.numpy())\n", + "plt.gca().set_aspect('equal', 'box')\n", + "plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "R8W1viPIovDu" + }, + "outputs": [], + "source": [ + "def aten_linalg_inv(g, arg):\n", + " return g.op(\"com.microsoft::Inverse\", arg)\n", + "\n", + "# Register custom symbolic function\n", + "torch.onnx.register_custom_op_symbolic(\"aten::linalg_inv\", aten_linalg_inv, 17)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tarbsJyKo2NU", + "outputId": "6546fd78-af31-44e5-e55b-7a5ed50635a5" + }, + "outputs": [], + "source": [ + "nfm.to(device)\n", + "\n", + "class FlowSampler(torch.nn.Module):\n", + " def __init__(self, flow_model):\n", + " super().__init__()\n", + " self.flow_model = flow_model\n", + "\n", + " def forward(self, z):\n", + " x = self.flow_model.forward(z)\n", + " return x\n", + "\n", + "sampler_model = FlowSampler(nfm)\n", + "sampler_model.eval()\n", + "\n", + "dummy_input = torch.randn(1000, latent_size, device=device)\n", + "onnx_model_path = \"onnx_sampler.onnx\"\n", + "\n", + "print(\"Exporting model to ONNX...\")\n", + "\n", + "torch.onnx.export(\n", + " sampler_model,\n", + " dummy_input,\n", + " onnx_model_path,\n", + " export_params=True,\n", + " opset_version=17,\n", + " input_names=['base_sample'],\n", + " output_names=['target_sample'],\n", + " dynamic_axes={\n", + " 'base_sample': {0: 'batch_size'},\n", + " 'target_sample': {0: 'batch_size'}\n", + " },\n", + " dynamo=False\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Q7Ttioaho_Zm", + "outputId": "ba8c17f4-f217-4955-b18c-fd4db8ba0bf4" + }, + "outputs": [], + "source": [ + "import onnxruntime as ort\n", + "\n", + "if device.type == 'cuda':\n", + " providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']\n", + "else:\n", + " providers = ['CPUExecutionProvider']\n", + "\n", + "\n", + "sess = ort.InferenceSession('onnx_sampler.onnx', providers=providers)\n", + "sample = np.random.randn(1000, 2).astype(np.float32)\n", + "output = sess.run(None, {'base_sample': sample})\n", + "output\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RRpdbSRSpvzv" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.11" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/normflows/flows/mixing.py b/normflows/flows/mixing.py index 5822836..2e2ea2b 100644 --- a/normflows/flows/mixing.py +++ b/normflows/flows/mixing.py @@ -449,27 +449,9 @@ def inverse_no_cache(self, inputs): N = num of inputs ``` """ - lower, upper = self._create_lower_upper() - outputs = inputs - self.bias - try: - outputs = torch.linalg.solve_triangular( - lower, outputs.t(), upper=False, unitriangular=True - ) - outputs = torch.linalg.solve_triangular( - upper, outputs, upper=True, unitriangular=False - ) - except: - outputs, _ = torch.triangular_solve( - outputs.t(), lower, upper=False, unitriangular=True - ) - outputs, _ = torch.triangular_solve( - outputs, upper, upper=True, unitriangular=False - ) - outputs = outputs.t() - - logabsdet = -self.logabsdet() - logabsdet = logabsdet * inputs.new_ones(outputs.shape[0]) - + W_inv = self.weight_inverse() + outputs = torch.matmul(inputs - self.bias, W_inv.t()) + logabsdet = -self.logabsdet() * inputs.new_ones(outputs.shape[0]) return outputs, logabsdet def weight(self): @@ -503,13 +485,7 @@ def weight_inverse(self): D = num of features ``` """ - lower, upper = self._create_lower_upper() - identity = torch.eye(self.features, self.features) - lower_inverse = torch.linalg.solve_triangular(lower, identity, upper=False, unitriangular=True) - weight_inverse = torch.linalg.solve_triangular( - upper, lower_inverse, upper=True, unitriangular=False - ) - return weight_inverse + return torch.inverse(self.weight()) @property def upper_diag(self): diff --git a/normflows/utils/splines.py b/normflows/utils/splines.py index ac59867..5352ace 100644 --- a/normflows/utils/splines.py +++ b/normflows/utils/splines.py @@ -26,25 +26,21 @@ def unconstrained_rational_quadratic_spline( min_derivative=DEFAULT_MIN_DERIVATIVE, ): inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) - outside_interval_mask = ~inside_interval_mask - - outputs = torch.zeros_like(inputs) - logabsdet = torch.zeros_like(inputs) if tails == "linear": - unnormalized_derivatives_ = F.pad(unnormalized_derivatives, pad=(1, 1)) + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) constant = np.log(np.exp(1 - min_derivative) - 1) - unnormalized_derivatives_[..., 0] = constant - unnormalized_derivatives_[..., -1] = constant - - outputs[outside_interval_mask] = inputs[outside_interval_mask] - logabsdet[outside_interval_mask] = 0 + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs_outside = inputs + logabsdet_outside = torch.zeros_like(inputs) elif tails == "circular": - unnormalized_derivatives_ = F.pad(unnormalized_derivatives, pad=(0, 1)) - unnormalized_derivatives_[..., -1] = unnormalized_derivatives_[..., 0] + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(0, 1)) + unnormalized_derivatives[..., -1] = unnormalized_derivatives[..., 0] - outputs[outside_interval_mask] = inputs[outside_interval_mask] - logabsdet[outside_interval_mask] = 0 + outputs_outside = inputs + logabsdet_outside = torch.zeros_like(inputs) elif isinstance(tails, list) or isinstance(tails, tuple): unnormalized_derivatives_ = unnormalized_derivatives.clone() ind_lin = [t == "linear" for t in tails] @@ -55,29 +51,39 @@ def unconstrained_rational_quadratic_spline( unnormalized_derivatives_[..., ind_circ, -1] = unnormalized_derivatives_[ ..., ind_circ, 0 ] + unnormalized_derivatives = unnormalized_derivatives_ + + outputs_outside = inputs + logabsdet_outside = torch.zeros_like(inputs) else: raise RuntimeError("{} tails are not implemented.".format(tails)) if torch.is_tensor(tail_bound): tail_bound_ = torch.broadcast_to(tail_bound, inputs.shape) - left = -tail_bound_[inside_interval_mask] - right = tail_bound_[inside_interval_mask] - bottom = -tail_bound_[inside_interval_mask] - top = tail_bound_[inside_interval_mask] + left = -tail_bound_ + right = tail_bound_ + bottom = -tail_bound_ + top = tail_bound_ + + # Specific bounds for clamping + clamp_min = -tail_bound_ + clamp_max = tail_bound_ else: left = -tail_bound right = tail_bound bottom = -tail_bound top = tail_bound + + clamp_min = -tail_bound + clamp_max = tail_bound + + inputs_clamped = torch.clamp(inputs, min=clamp_min, max=clamp_max) - ( - outputs_masked, - logabsdet_masked - ) = rational_quadratic_spline( - inputs=inputs[inside_interval_mask], - unnormalized_widths=unnormalized_widths[inside_interval_mask, :], - unnormalized_heights=unnormalized_heights[inside_interval_mask, :], - unnormalized_derivatives=unnormalized_derivatives_[inside_interval_mask, :], + outputs_spline, logabsdet_spline = rational_quadratic_spline( + inputs=inputs_clamped, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, inverse=inverse, left=left, right=right, @@ -87,12 +93,9 @@ def unconstrained_rational_quadratic_spline( min_bin_height=min_bin_height, min_derivative=min_derivative, ) - if outputs.dtype == outputs_masked.dtype and logabsdet.dtype == logabsdet_masked.dtype: - outputs[inside_interval_mask] = outputs_masked - logabsdet[inside_interval_mask] = logabsdet_masked - else: - outputs[inside_interval_mask] = outputs_masked.to(outputs.dtype) - logabsdet[inside_interval_mask] = logabsdet_masked.to(logabsdet.dtype) + + outputs = torch.where(inside_interval_mask, outputs_spline, outputs_outside) + logabsdet = torch.where(inside_interval_mask, logabsdet_spline, logabsdet_outside) return outputs, logabsdet diff --git a/setup.py b/setup.py index 561a217..49c694f 100755 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from codecs import open from os import path -__version__ = "1.7.3" +__version__ = "1.7.4" here = path.abspath(path.dirname(__file__))