Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 315 additions & 0 deletions examples/onnx_export.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "wUYE8Wz2ojXF"
},
"source": [
"# ONNX export example for Neural Spline Flow"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "niwtdQWHojXF"
},
"outputs": [],
"source": [
"# Import required packages\n",
"import torch\n",
"import numpy as np\n",
"import normflows as nf\n",
"\n",
"from sklearn.datasets import make_moons\n",
"\n",
"from matplotlib import pyplot as plt\n",
"\n",
"from tqdm import tqdm\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ocyo6e9SojXG",
"scrolled": false
},
"outputs": [],
"source": [
"# Set up model\n",
"\n",
"# Define flows\n",
"K = 16\n",
"torch.manual_seed(0)\n",
"\n",
"latent_size = 2\n",
"hidden_units = 128\n",
"hidden_layers = 2\n",
"\n",
"flows = []\n",
"for i in range(K):\n",
" flows += [nf.flows.AutoregressiveRationalQuadraticSpline(latent_size, hidden_layers, hidden_units)]\n",
" flows += [nf.flows.LULinearPermute(latent_size)]\n",
"\n",
"# Set base distribution\n",
"q0 = nf.distributions.DiagGaussian(2, trainable=False)\n",
"\n",
"# Construct flow model\n",
"nfm = nf.NormalizingFlow(q0=q0, flows=flows)\n",
"\n",
"# Move model to GPU if available\n",
"enable_cuda = True\n",
"device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')\n",
"nfm = nfm.to(device)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "gfaPXsACojXG",
"outputId": "44658aca-aff5-41c5-a8e2-651b8abaf84a",
"scrolled": false
},
"outputs": [],
"source": [
"# Plot target distribution\n",
"x_np, _ = make_moons(2 ** 20, noise=0.1)\n",
"plt.figure(figsize=(15, 15))\n",
"plt.hist2d(x_np[:, 0], x_np[:, 1], bins=200)\n",
"plt.show()\n",
"\n",
"# Plot initial flow distribution\n",
"grid_size = 100\n",
"xx, yy = torch.meshgrid(torch.linspace(-1.5, 2.5, grid_size), torch.linspace(-2, 2, grid_size))\n",
"zz = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).view(-1, 2)\n",
"zz = zz.to(device)\n",
"\n",
"nfm.eval()\n",
"log_prob = nfm.log_prob(zz).to('cpu').view(*xx.shape)\n",
"nfm.train()\n",
"prob = torch.exp(log_prob)\n",
"prob[torch.isnan(prob)] = 0\n",
"\n",
"plt.figure(figsize=(15, 15))\n",
"plt.pcolormesh(xx, yy, prob.data.numpy())\n",
"plt.gca().set_aspect('equal', 'box')\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "2ratRmYAojXG",
"outputId": "8359d715-53f2-4cef-fd3e-5fa6d74c80ed",
"scrolled": false
},
"outputs": [],
"source": [
"# Train model\n",
"max_iter = 10\n",
"num_samples = 2 ** 9\n",
"show_iter = 5\n",
"\n",
"\n",
"loss_hist = np.array([])\n",
"\n",
"optimizer = torch.optim.Adam(nfm.parameters(), lr=1e-3, weight_decay=1e-5)\n",
"for it in tqdm(range(max_iter)):\n",
" optimizer.zero_grad()\n",
"\n",
" # Get training samples\n",
" x_np, _ = make_moons(num_samples, noise=0.1)\n",
" x = torch.tensor(x_np).float().to(device)\n",
"\n",
" # Compute loss\n",
" loss = nfm.forward_kld(x)\n",
"\n",
" # Do backprop and optimizer step\n",
" if ~(torch.isnan(loss) | torch.isinf(loss)):\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" # Log loss\n",
" loss_hist = np.append(loss_hist, loss.to('cpu').data.numpy())\n",
"\n",
" # Plot learned distribution\n",
" if (it + 1) % show_iter == 0:\n",
" nfm.eval()\n",
" log_prob = nfm.log_prob(zz)\n",
" nfm.train()\n",
" prob = torch.exp(log_prob.to('cpu').view(*xx.shape))\n",
" prob[torch.isnan(prob)] = 0\n",
"\n",
" plt.figure(figsize=(15, 15))\n",
" plt.pcolormesh(xx, yy, prob.data.numpy())\n",
" plt.gca().set_aspect('equal', 'box')\n",
" plt.show()\n",
"\n",
"# Plot loss\n",
"plt.figure(figsize=(10, 10))\n",
"plt.plot(loss_hist, label='loss')\n",
"plt.legend()\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 754
},
"id": "KTeqlOJ-ojXG",
"outputId": "7644d64e-7f84-4d34-81be-d5d3732182a8"
},
"outputs": [],
"source": [
"# Plot learned distribution\n",
"nfm.eval()\n",
"log_prob = nfm.log_prob(zz).to('cpu').view(*xx.shape)\n",
"nfm.train()\n",
"prob = torch.exp(log_prob)\n",
"prob[torch.isnan(prob)] = 0\n",
"\n",
"plt.figure(figsize=(15, 15))\n",
"plt.pcolormesh(xx, yy, prob.data.numpy())\n",
"plt.gca().set_aspect('equal', 'box')\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "R8W1viPIovDu"
},
"outputs": [],
"source": [
"def aten_linalg_inv(g, arg):\n",
" return g.op(\"com.microsoft::Inverse\", arg)\n",
"\n",
"# Register custom symbolic function\n",
"torch.onnx.register_custom_op_symbolic(\"aten::linalg_inv\", aten_linalg_inv, 17)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tarbsJyKo2NU",
"outputId": "6546fd78-af31-44e5-e55b-7a5ed50635a5"
},
"outputs": [],
"source": [
"nfm.to(device)\n",
"\n",
"class FlowSampler(torch.nn.Module):\n",
" def __init__(self, flow_model):\n",
" super().__init__()\n",
" self.flow_model = flow_model\n",
"\n",
" def forward(self, z):\n",
" x = self.flow_model.forward(z)\n",
" return x\n",
"\n",
"sampler_model = FlowSampler(nfm)\n",
"sampler_model.eval()\n",
"\n",
"dummy_input = torch.randn(1000, latent_size, device=device)\n",
"onnx_model_path = \"onnx_sampler.onnx\"\n",
"\n",
"print(\"Exporting model to ONNX...\")\n",
"\n",
"torch.onnx.export(\n",
" sampler_model,\n",
" dummy_input,\n",
" onnx_model_path,\n",
" export_params=True,\n",
" opset_version=17,\n",
" input_names=['base_sample'],\n",
" output_names=['target_sample'],\n",
" dynamic_axes={\n",
" 'base_sample': {0: 'batch_size'},\n",
" 'target_sample': {0: 'batch_size'}\n",
" },\n",
" dynamo=False\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Q7Ttioaho_Zm",
"outputId": "ba8c17f4-f217-4955-b18c-fd4db8ba0bf4"
},
"outputs": [],
"source": [
"import onnxruntime as ort\n",
"\n",
"if device.type == 'cuda':\n",
" providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']\n",
"else:\n",
" providers = ['CPUExecutionProvider']\n",
"\n",
"\n",
"sess = ort.InferenceSession('onnx_sampler.onnx', providers=providers)\n",
"sample = np.random.randn(1000, 2).astype(np.float32)\n",
"output = sess.run(None, {'base_sample': sample})\n",
"output\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RRpdbSRSpvzv"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
32 changes: 4 additions & 28 deletions normflows/flows/mixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,27 +449,9 @@ def inverse_no_cache(self, inputs):
N = num of inputs
```
"""
lower, upper = self._create_lower_upper()
outputs = inputs - self.bias
try:
outputs = torch.linalg.solve_triangular(
lower, outputs.t(), upper=False, unitriangular=True
)
outputs = torch.linalg.solve_triangular(
upper, outputs, upper=True, unitriangular=False
)
except:
outputs, _ = torch.triangular_solve(
outputs.t(), lower, upper=False, unitriangular=True
)
outputs, _ = torch.triangular_solve(
outputs, upper, upper=True, unitriangular=False
)
outputs = outputs.t()

logabsdet = -self.logabsdet()
logabsdet = logabsdet * inputs.new_ones(outputs.shape[0])

W_inv = self.weight_inverse()
outputs = torch.matmul(inputs - self.bias, W_inv.t())
logabsdet = -self.logabsdet() * inputs.new_ones(outputs.shape[0])
return outputs, logabsdet

def weight(self):
Expand Down Expand Up @@ -503,13 +485,7 @@ def weight_inverse(self):
D = num of features
```
"""
lower, upper = self._create_lower_upper()
identity = torch.eye(self.features, self.features)
lower_inverse = torch.linalg.solve_triangular(lower, identity, upper=False, unitriangular=True)
weight_inverse = torch.linalg.solve_triangular(
upper, lower_inverse, upper=True, unitriangular=False
)
return weight_inverse
return torch.inverse(self.weight())

@property
def upper_diag(self):
Expand Down
Loading