Rips Optimization

Optimization of data point positions to increase or decrease polynomial barcode features of the persistent homology computed from a Rips filtration.

[1]:
import time
import torch
import torch.nn as nn
import torch_tda
import numpy as np
import matplotlib.pyplot as plt
import bats
from tqdm import tqdm
[2]:
# Sample n points uniformly from the unit square. Seeding the global NumPy
# RNG makes the point cloud, and therefore the whole run, reproducible.
n = 100
np.random.seed(0)
data = np.random.uniform(0, 1, (n, 2))

# Use the explicit Figure/Axes interface. plt.scatter returns a
# PathCollection, not a figure.
fig1, ax1 = plt.subplots()
ax1.scatter(data[:, 0], data[:, 1])
ax1.set_aspect('equal')
ax1.set_title("Initial point cloud")

# torch.tensor copies `data`, so optimizing X later leaves `data` unchanged
# for the "Before" panel. requires_grad=True lets the optimizer move the points.
X = torch.tensor(data, requires_grad=True)
../_images/examples_rips_opt_2_0.png
[3]:
# Build the Rips persistence layer. Nothing is computed here; diagrams are
# produced later when the layer is called on a point cloud (layer(X)).
# maxdim=1 yields diagrams for H0 and H1, which requires 2-simplices (C_2)
# in the filtration.
# Reduction options are forwarded to bats. Other settings tried during
# development were (bats.standard_reduction_flag(), bats.compute_basis_flag())
# and no flags at all: ().
flags = (bats.standard_reduction_flag(), bats.compression_flag())
layer = torch_tda.nn.RipsLayer(maxdim=1, reduction_flags=flags)
[4]:
# Scalar feature of the persistence diagrams that the optimization below
# maximizes. NOTE(review): the positional args (1, 2, 0) presumably select
# homology dimension 1 with polynomial exponents p=2, q=0 (a topologylayer-style
# sum over bars of length^p * midpoint^q). TODO: confirm against the torch_tda
# BarcodePolyFeature signature.
f1 = torch_tda.nn.BarcodePolyFeature(1,2,0)
[5]:
# Gradient ascent on the barcode feature f1. Adam minimizes, so minimizing
# -f1 moves the points in the direction that increases f1.
N_STEPS = 100  # number of Adam updates
LR = 1e-2      # Adam learning rate

optimizer = torch.optim.Adam([X], lr=LR)
for _ in tqdm(range(N_STEPS)):
    optimizer.zero_grad()
    dgms = layer(X)   # recompute diagrams from the current point positions
    loss = -f1(dgms)  # negated so that minimizing the loss maximizes f1
    loss.backward()
    optimizer.step()
100%|██████████| 100/100 [00:19<00:00,  5.15it/s]
[6]:
# Compare the point cloud before and after optimization, then save the figure.
# X requires grad, so it must be detached before conversion to NumPy.
y = X.detach().numpy()
fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
for panel, points, title in zip(ax, (data, y), ("Before", "After")):
    panel.scatter(points[:, 0], points[:, 1])
    panel.set_title(title)
    # Same equal aspect as the initial plot, so distances are not distorted.
    panel.set_aspect('equal')
    # Hide tick labels and tick marks; only the point geometry matters here.
    panel.set_xticklabels([])
    panel.set_yticklabels([])
    panel.tick_params(bottom=False, left=False)
fig.savefig('holes.png')
../_images/examples_rips_opt_6_0.png
[ ]: