TensorTrainNumerics.jl

Tensor Train Numerics is a Julia package that provides efficient numerical methods for tensor trains (TT) and quantized tensor trains (QTT). It offers tools for constructing tensor trains, manipulating them, and performing operations on them, with applications in high-dimensional data analysis, machine learning, and computational physics.

Key features

  • Tensor Train Decomposition: Efficient algorithms for decomposing high-dimensional tensors into tensor train format, reducing computational complexity and memory usage.
  • Tensor Operations: Support for basic tensor operations such as addition, multiplication, and contraction in tensor train format.
  • Discrete Operators: Implementation of discrete Laplacians, gradient operators, and shift matrices in tensor train format for solving partial differential equations and other numerical problems.
  • Quantized Tensor Trains: Tools for constructing and manipulating quantized tensor trains, which provide further compression and efficiency for large-scale problems.
  • Iterative Solvers: Integration with iterative solvers for solving linear systems and eigenvalue problems in tensor train format.
  • Visualization: Basic visualization tools for inspecting tensor train structures and their properties.
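
To make the decomposition feature concrete, here is a minimal sketch of the classical TT-SVD idea (sequential truncated SVDs of the unfoldings) written with only the LinearAlgebra standard library. The function name tt_svd_sketch, the tol keyword, and the (r_prev, n_k, r_next) core layout are assumptions made for this illustration; they are not the package's API and need not match its internal storage.

```julia
using LinearAlgebra

# Decompose a dense d-way array into TT cores by sequential truncated SVDs.
# Core k has size (r_prev, n_k, r_next); this layout is an assumption of the sketch.
function tt_svd_sketch(A::AbstractArray{T}; tol::Real = 1e-12) where {T<:AbstractFloat}
    dims = size(A)
    d = length(dims)
    cores = Vector{Array{T,3}}(undef, d)
    r_prev = 1
    # First unfolding: rows indexed by the first mode, columns by all the others.
    C = reshape(copy(A), dims[1], :)
    for k in 1:(d - 1)
        F = svd(C)
        # Keep singular values above a relative threshold (always at least one).
        r = max(1, count(>(tol * F.S[1]), F.S))
        cores[k] = reshape(F.U[:, 1:r], r_prev, dims[k], r)
        # Carry the remainder forward and fold the next mode into the rows.
        C = reshape(Diagonal(F.S[1:r]) * F.Vt[1:r, :], r * dims[k + 1], :)
        r_prev = r
    end
    cores[d] = reshape(C, r_prev, dims[d], 1)
    return cores
end

# sin(i + 2j + 3k) splits into sine and cosine terms, so its TT ranks are 2.
A = [sin(i + 2j + 3k) for i in 1:3, j in 1:4, k in 1:5]
cores = tt_svd_sketch(A)
tt_ranks = [size(c, 3) for c in cores[1:end-1]]
```

The 60 entries of A are stored in three cores with 6, 16, and 10 entries. The saving grows quickly with the number of modes as long as the ranks stay small.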

Getting started

Install the package with Julia's package manager:

using Pkg
Pkg.add("TensorTrainNumerics")

Basic example

using TensorTrainNumerics

# Define the dimensions and ranks for the TTvector
dims = (2, 2, 2)
rks = [1, 2, 2, 1]

# Create a random TTvector
tt_vec = rand_tt(dims, rks)

# Define the dimensions and ranks for the TToperator
op_dims = (2, 2, 2)
op_rks = [1, 2, 2, 1]  # note: not used below; rand_tto is called with the integer 3

# Create a random TToperator
tt_op = rand_tto(op_dims, 3)

# Perform the multiplication
result = tt_op * tt_vec

# Visualize the result
visualize(result)
 1-- • -- 4-- • -- 4-- • -- 1
     |        |        |
     2        2        2

We can also print the result:

println(result)
TTvector{Float64, 3}(3, [[2.948938482500713; -4.529534351806045;;; 0.26357802309854933; -0.5448891250883173;;; 0.0396664377962828; 0.41280884615651253;;; -0.12184894836288908; -0.12654317220343364], [0.2816263160848493 0.8267753573221822 -0.33901199284112926 -1.5257394479496933; -0.36089220108714026 0.6481432079091949 -0.1508831725075234 -0.8021379030376357;;; -0.9904384899994152 -0.027616975315835374 0.7129968302020097 0.2832232333379341; 1.1830900203600967 -0.3122212851905653 -0.31139269792615365 -0.561559186474216;;; 0.22151039343681234 1.3173301112578135 -0.6630271004708472 -4.765824101767726; 0.4521077427746758 0.5373595529599737 -2.2610505903906697 -1.6424356603353671;;; -0.17640648565058484 -0.3360417781064465 -0.21528993636943006 1.3935510914698073; -0.46863362766091504 0.9330994282689975 3.128581876956841 -4.246833921485594], [-1.6817616645933104 0.9120361742812599 0.17420603823826125 0.5247212572692388; -1.063189215670791 0.22629101459053974 -0.6727046001529923 0.16495712251594544;;;]], (2, 2, 2), [1, 4, 4, 1], [0, 0, 0])

We can also unfold the result into a full vector:

matricize(result)
8-element Vector{Float64}:
  0.05704431803498263
  0.16546137643790296
  4.560581333965011
 -6.458142643036585
 -0.1453906122309895
  0.06953869178682208
  0.5306460288121858
 -0.8011747093897214
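
To illustrate what unfolding a tensor train into a full vector involves, the sketch below contracts a list of cores one by one using plain Julia arrays. It is not the package's matricize: the helper name tt_to_vector_sketch, the (r_prev, n_k, r_next) core layout, and the choice of letting the first index run fastest are all assumptions of this example, and the package's ordering convention may differ.

```julia
using LinearAlgebra

# Contract TT cores of size (r_prev, n_k, r_next) into the full vector,
# with the first index running fastest (Julia's column-major order).
function tt_to_vector_sketch(cores)
    # The first core has r_prev == 1, so it can be viewed as an (n_1, r_1) matrix.
    full = reshape(cores[1], size(cores[1], 2), :)
    for k in 2:length(cores)
        r, n, r_next = size(cores[k])
        # Multiply over the shared rank index, then fold mode k into the rows.
        full = reshape(full * reshape(cores[k], r, n * r_next), :, r_next)
    end
    return vec(full)
end

# Rank-1 example: three cores built from the vectors a, b, and c.
a, b, c = [1.0, 2.0], [1.0, 10.0], [1.0, 100.0]
cores = [reshape(a, 1, 2, 1), reshape(b, 1, 2, 1), reshape(c, 1, 2, 1)]
v = tt_to_vector_sketch(cores)
```

With this ordering, v equals the Kronecker product kron(c, b, a), which is the full 8-element vector that the three 2-element cores represent.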

Interpolation

We can also do interpolation in the QTT framework:

using CairoMakie
using TensorTrainNumerics

f = x -> cos(1 / (x^3 + 0.01)) + sin(π * x)
num_cores = 10
N = 150

qtt = interpolating_qtt(f, num_cores, N)
qtt_rank_revealing = lagrange_rank_revealing(f, num_cores, N)

qtt_values = matricize(qtt, num_cores)
qtt_values_rank_revealing = matricize(qtt_rank_revealing, num_cores)

x_points = LinRange(0, 1, 2^num_cores)
original_values = f.(x_points)

fig = Figure()
ax = Axis(fig[1, 1], title="Function Approximation", xlabel="x", ylabel="f(x)")

lines!(ax, x_points, original_values, label="Original Function")
lines!(ax, x_points, qtt_values_rank_revealing, label="QTT, rank rev.", linestyle=:dash, color=:green)
lines!(ax, x_points, qtt_values, label="QTT", linestyle=:dash, color=:red)

axislegend(ax)
fig

We can visualize the ranks of the interpolating QTT:

visualize(qtt)
  1-- • --151-- • --151-- • --151-- • --151-- • --151-- • --151-- • --151-- • --151-- • --151-- • --  1
      |         |         |         |         |         |         |         |         |         |
      2         2         2         2         2         2         2         2         2         2

And similarly for the rank-revealing QTT:

visualize(qtt_rank_revealing)
 1-- • -- 2-- • -- 4-- • -- 8-- • --14-- • --20-- • --19-- • --15-- • --13-- • --11-- • -- 1
     |        |        |        |        |        |        |        |        |        |
     2        2        2        2        2        2        2        2        2        2
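
The numbers on the horizontal bonds in these diagrams are the QTT ranks. As a rough illustration of what they measure, the sketch below reshapes 2^d function samples into a 2 x 2 x ... x 2 tensor and computes the numerical rank of each unfolding, which gives the smallest ranks an exact QTT of those samples can have. The helper qtt_ranks_sketch and its tol keyword are hypothetical names for this example; it uses only the LinearAlgebra standard library and says nothing about how interpolating_qtt or lagrange_rank_revealing work internally.

```julia
using LinearAlgebra

# Numerical rank of every unfolding of the 2 x 2 x ... x 2 tensor obtained by
# reshaping 2^d samples. Row indices collect the first k binary digits.
function qtt_ranks_sketch(values::AbstractVector{<:Real}; tol::Real = 1e-10)
    d = round(Int, log2(length(values)))
    2^d == length(values) || throw(ArgumentError("length must be a power of two"))
    ranks = Int[]
    for k in 1:(d - 1)
        s = svdvals(reshape(values, 2^k, :))
        # Count singular values above a relative threshold.
        push!(ranks, count(>(tol * s[1]), s))
    end
    return ranks
end

d = 10
x = range(0, 1; length = 2^d)
ranks_exp = qtt_ranks_sketch(exp.(x))        # exponential: separable in the binary digits
ranks_sin = qtt_ranks_sketch(sin.(π .* x))   # sine: a sum of two separable terms
```

On a uniform grid the exponential has all QTT ranks equal to 1 and the sine has all ranks equal to 2, which is why smooth functions compress well in this format. An oscillatory function such as the f used above needs larger ranks, consistent with the rank-revealing diagram.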