TensorTrainNumerics.jl
TensorTrainNumerics.jl is a Julia package that provides numerical methods for tensor trains (TT) and quantized tensor trains (QTT). It includes tools for constructing, manipulating, and computing with tensor trains, a format used in scientific and engineering applications such as high-dimensional data analysis, machine learning, and computational physics.
Key features
- Tensor Train Decomposition: Efficient algorithms for decomposing high-dimensional tensors into tensor train format, reducing computational complexity and memory usage.
- Tensor Operations: Support for basic tensor operations such as addition, multiplication, and contraction in tensor train format.
- Discrete Operators: Implementation of discrete Laplacians, gradient operators, and shift matrices in tensor train format for solving partial differential equations and other numerical problems.
- Quantized Tensor Trains: Tools for constructing and manipulating quantized tensor trains, which provide further compression and efficiency for large-scale problems.
- Iterative Solvers: Integration with iterative solvers for solving linear systems and eigenvalue problems in tensor train format.
- Visualization: Basic visualization tools for inspecting tensor train structures and their properties.
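To illustrate the first feature, here is a minimal sketch of the standard TT-SVD algorithm in plain Julia. It sweeps over the modes, takes a truncated SVD of each unfolding, and keeps the left singular vectors as a core. It does not use TensorTrainNumerics.jl: the names tt_svd and tt_entry, the core layout n_k × r_{k-1} × r_k, and the relative tolerance rtol are choices made for this example only.

```julia
using LinearAlgebra

# Hypothetical sketch of TT-SVD, not the TensorTrainNumerics.jl implementation.
# Cores are stored as n_k × r_{k-1} × r_k arrays, with r_0 = r_d = 1.
function tt_svd(A::AbstractArray{Float64}; rtol::Float64 = 1e-12)
    dims = size(A)
    d = length(dims)
    cores = Vector{Array{Float64,3}}(undef, d)
    r_prev = 1
    C = reshape(collect(A), dims[1], :)             # unfolding n_1 × (n_2 ⋯ n_d)
    for k in 1:d-1
        F = svd(C)
        # Truncate: keep singular values above rtol times the largest one
        r = max(1, count(>(rtol * F.S[1]), F.S))
        U = reshape(F.U[:, 1:r], r_prev, dims[k], r)
        cores[k] = permutedims(U, (2, 1, 3))
        # Pass S * Vᵀ on and fold mode k + 1 into the rows of the next unfolding
        C = reshape(Diagonal(F.S[1:r]) * F.Vt[1:r, :], r * dims[k+1], :)
        r_prev = r
    end
    cores[d] = permutedims(reshape(C, r_prev, dims[d], 1), (2, 1, 3))
    return cores
end

# One entry of the TT tensor: the product of the matrices G_k[i_k, :, :]
tt_entry(cores, idx) = only(prod(cores[k][idx[k], :, :] for k in eachindex(cores)))

# Example: a rank-1 tensor of size 4 × 5 × 6
a, b, c = rand(4), rand(5), rand(6)
A = [a[i] * b[j] * c[k] for i in 1:4, j in 1:5, k in 1:6]
cores = tt_svd(A)
ranks = [size(G, 3) for G in cores]     # [1, 1, 1]
n_tt = sum(length, cores)               # 4 + 5 + 6 = 15 stored numbers
n_full = length(A)                      # 4 * 5 * 6 = 120 numbers
```

For tensors with small TT ranks, the number of stored values grows linearly with the number of modes rather than exponentially, which is where the memory savings come from.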
Getting started
Install TensorTrainNumerics.jl with Julia's package manager:
using Pkg
Pkg.add("TensorTrainNumerics")
Basic example
using TensorTrainNumerics
# Define the dimensions and ranks for the TTvector
dims = (2, 2, 2)
rks = [1, 2, 2, 1]
# Create a random TTvector
tt_vec = rand_tt(dims, rks)
# Define the dimensions for the TToperator
op_dims = (2, 2, 2)
# Create a random TToperator with these dimensions
tt_op = rand_tto(op_dims, 3)
# Perform the multiplication
result = tt_op * tt_vec
# Visualize the result
visualize(result)
1-- • -- 4-- • -- 4-- • -- 1
    |        |        |
    2        2        2
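A TT operator can be applied to a TT vector core by core, without forming dense arrays: core k of the product combines core k of the operator with core k of the vector, so the bond dimensions of the exact product are the products of the two factors' bond dimensions. The plain-Julia sketch below illustrates this with ranks [1, 2, 2, 1] for both factors, which gives a product with ranks [1, 4, 4, 1]. It is not the package's implementation: tto_times_tt, tt_entry and tto_entry are hypothetical names, and the core layouts (n_k × r_{k-1} × r_k for vectors, m_k × n_k × R_{k-1} × R_k for operators) are assumptions made for the sketch. The last lines check the result against a dense matrix-vector product.

```julia
# Hypothetical sketch, not the TensorTrainNumerics.jl implementation.
# Vector cores: n_k × r_{k-1} × r_k.  Operator cores: m_k × n_k × R_{k-1} × R_k.
function tto_times_tt(A::Vector{<:Array{T,4}}, X::Vector{<:Array{T,3}}) where {T}
    return map(zip(A, X)) do (Ak, Xk)
        m, n, Ra, Rb = size(Ak)
        nx, ra, rb = size(Xk)
        @assert n == nx "mode sizes must match"
        Y = zeros(T, m, Ra, ra, Rb, rb)
        # Y[i, (a,α), (b,β)] = Σ_j Ak[i, j, a, b] * Xk[j, α, β]
        for β in 1:rb, b in 1:Rb, α in 1:ra, a in 1:Ra, i in 1:m
            Y[i, a, α, b, β] = sum(Ak[i, j, a, b] * Xk[j, α, β] for j in 1:n)
        end
        # Merge (a, α) and (b, β): the new bond sizes are Ra*ra and Rb*rb
        reshape(Y, m, Ra * ra, Rb * rb)
    end
end

# Entries as products of core slices (used only to check the sketch)
tt_entry(G, i) = only(prod(G[k][i[k], :, :] for k in eachindex(G)))
tto_entry(A, i, j) = only(prod(A[k][i[k], j[k], :, :] for k in eachindex(A)))

r = [1, 2, 2, 1]
X = [randn(2, r[k], r[k+1]) for k in 1:3]       # random TT vector, dims (2, 2, 2)
A = [randn(2, 2, r[k], r[k+1]) for k in 1:3]    # random TT operator, ranks [1, 2, 2, 1]
Y = tto_times_tt(A, X)
ranks = [1; [size(G, 3) for G in Y]]             # [1, 4, 4, 1]

idx = CartesianIndices((2, 2, 2))
y_dense = [sum(tto_entry(A, Tuple(I), Tuple(J)) * tt_entry(X, Tuple(J)) for J in idx) for I in idx]
y_tt = [tt_entry(Y, Tuple(I)) for I in idx]
y_tt ≈ y_dense                                   # true
```

Because ranks multiply, repeated products are usually followed by a rounding (recompression) step in TT algorithms to keep the bond dimensions small.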
We can also print the result (the values differ from run to run, because the tensors are random):
println(result)
TTvector{Float64, 3}(3, [[2.948938482500713; -4.529534351806045;;; 0.26357802309854933; -0.5448891250883173;;; 0.0396664377962828; 0.41280884615651253;;; -0.12184894836288908; -0.12654317220343364], [0.2816263160848493 0.8267753573221822 -0.33901199284112926 -1.5257394479496933; -0.36089220108714026 0.6481432079091949 -0.1508831725075234 -0.8021379030376357;;; -0.9904384899994152 -0.027616975315835374 0.7129968302020097 0.2832232333379341; 1.1830900203600967 -0.3122212851905653 -0.31139269792615365 -0.561559186474216;;; 0.22151039343681234 1.3173301112578135 -0.6630271004708472 -4.765824101767726; 0.4521077427746758 0.5373595529599737 -2.2610505903906697 -1.6424356603353671;;; -0.17640648565058484 -0.3360417781064465 -0.21528993636943006 1.3935510914698073; -0.46863362766091504 0.9330994282689975 3.128581876956841 -4.246833921485594], [-1.6817616645933104 0.9120361742812599 0.17420603823826125 0.5247212572692388; -1.063189215670791 0.22629101459053974 -0.6727046001529923 0.16495712251594544;;;]], (2, 2, 2), [1, 4, 4, 1], [0, 0, 0])
We can also contract the result into a full (dense) vector of length 2 × 2 × 2 = 8:
matricize(result)
8-element Vector{Float64}:
0.05704431803498263
0.16546137643790296
4.560581333965011
-6.458142643036585
-0.1453906122309895
0.06953869178682208
0.5306460288121858
-0.8011747093897214
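Conceptually, this contraction sums over all bond indices, absorbing one core at a time into a growing matrix. The printed cores above appear to be stored as n_k × r_{k-1} × r_k arrays (the first core prints as a 2 × 1 × 4 array), so the plain-Julia sketch below uses that layout. The helper tt_to_vector is hypothetical, and the output order (first mode varying fastest, Julia's column-major convention) is an assumption; TensorTrainNumerics.jl may order the entries differently.

```julia
# Hypothetical helper, not part of TensorTrainNumerics.jl.
# Cores are n_k × r_{k-1} × r_k arrays; the first mode varies fastest in the output.
function tt_to_vector(cores::Vector{<:AbstractArray{T,3}}) where {T}
    n1, _, r1 = size(cores[1])
    M = reshape(cores[1], n1, r1)                 # n_1 × r_1 (r_0 = 1)
    for G in cores[2:end]
        n, r_prev, r = size(G)
        # Rearrange the core to r_{k-1} × (n_k r_k) and absorb it into M
        Gm = reshape(permutedims(G, (2, 1, 3)), r_prev, n * r)
        M = reshape(M * Gm, :, r)                 # rows now index (i_1, ..., i_k)
    end
    return vec(M)                                 # r_d = 1, so M is a single column
end

r = [1, 2, 2, 1]                                  # ranks as in the basic example
G = [randn(2, r[k], r[k+1]) for k in 1:3]
v = tt_to_vector(G)                               # 8-element Vector{Float64}
```

Each step multiplies an (n_1 ⋯ n_{k-1}) × r_{k-1} matrix by one reshaped core, so the cost is dominated by the size of the full vector; this is why dense expansion is only practical for small examples like this one.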
Interpolation
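We can also interpolate functions in the QTT format. The example below builds two QTT approximations of f(x) = cos(1/(x^3 + 0.01)) + sin(πx) with 10 cores and plots them against f at 2^10 = 1024 equally spaced points in [0, 1] using CairoMakie: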
using CairoMakie
using TensorTrainNumerics
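# Function to approximate
f = x -> cos(1 / (x^3 + 0.01)) + sin(π * x)
# Number of QTT cores (the grid has 2^num_cores points) and the interpolation parameter N
num_cores = 10
N = 150
# Build the interpolating QTT and the rank-revealing variant
qtt = interpolating_qtt(f, num_cores, N)
qtt_rank_revealing = lagrange_rank_revealing(f, num_cores, N)
# Expand both QTTs into dense vectors of length 2^num_cores
qtt_values = matricize(qtt, num_cores)
qtt_values_rank_revealing = matricize(qtt_rank_revealing, num_cores)
# Sample f on the plotting grid for comparison
x_points = LinRange(0, 1, 2^num_cores)
original_values = f.(x_points)
# Plot f and both QTT approximations
fig = Figure()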
ax = Axis(fig[1, 1], title="Function Approximation", xlabel="x", ylabel="f(x)")
lines!(ax, x_points, original_values, label="Original Function")
lines!(ax, x_points, qtt_values_rank_revealing, label="QTT, rank rev.", linestyle=:dash, color=:green)
lines!(ax, x_points, qtt_values, label="QTT", linestyle=:dash, color=:red)
axislegend(ax)
fig

We can visualize the interpolating QTT. Every internal bond has dimension N + 1 = 151:
visualize(qtt)
1-- • --151-- • --151-- • --151-- • --151-- • --151-- • --151-- • --151-- • --151-- • --151-- • -- 1
    |         |         |         |         |         |         |         |         |         |
    2         2         2         2         2         2         2         2         2         2
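Each core in the diagram has a mode of size 2. In a QTT, the 2^10 = 1024 grid values are viewed as a tensor with ten modes of size 2: the grid index i - 1 is written in binary, and each bit becomes one tensor index. The plain-Julia snippet below illustrates this quantization with reshape. It is a sketch under assumptions: the least significant bit is taken as the first mode (Julia's column-major order), which may not be the bit order TensorTrainNumerics.jl uses, and qtt_bits is a hypothetical helper. It also shows why exponentials are cheap in this format: on a uniform grid starting at 0, exp(αx_i) is a product of one factor per bit, so its QTT ranks are all 1.

```julia
d = 10
x = LinRange(0, 1, 2^d)              # same grid as x_points in the example above
h = step(x)                          # grid spacing, so x_i = h * (i - 1)
α = 2.0
v = exp.(α .* x)                     # samples of exp(αx)

# Quantization: view the 2^d samples as a 2 × 2 × ⋯ × 2 tensor with d modes
Q = reshape(v, ntuple(_ -> 2, d))

# Hypothetical helper: bits of the grid index i - 1, least significant first
qtt_bits(i, d) = ntuple(k -> ((i - 1) >> (k - 1)) & 1, d)

i = 700
Q[(qtt_bits(i, d) .+ 1)...] == v[i]  # true: bit k is the k-th tensor index

# Since x_i = h * Σ_k b_k 2^(k-1), exp(α x_i) factors into one term per bit,
# so these samples have a QTT with all bond dimensions equal to 1.
factors = [[1.0, exp(α * h * 2^(k - 1))] for k in 1:d]
b = qtt_bits(i, d)
prod(factors[k][b[k] + 1] for k in 1:d) ≈ v[i]   # true
```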
Similarly, for the rank-revealing QTT, whose bond dimensions are much smaller (at most 20):
visualize(qtt_rank_revealing)
1-- • -- 2-- • -- 4-- • -- 8-- • --14-- • --20-- • --19-- • --15-- • --13-- • --11-- • -- 1
    |        |        |        |        |        |        |        |        |        |
    2        2        2        2        2        2        2        2        2        2
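The bond dimensions in these diagrams depend on how each QTT is constructed. For comparison, the smallest bond dimension of an exact QTT of the sampled values at bond k equals the rank of the k-th unfolding, a 2^k × 2^(d-k) matrix, so it is at most min(2^k, 2^(d-k)). The sketch below estimates these ranks for the same f with LinearAlgebra.svdvals. It follows the TT-SVD idea rather than the Lagrange-based construction behind lagrange_rank_revealing; the helper qtt_unfolding_ranks and the tolerance 1e-10 are illustrative choices, and the bit order assumes the least significant bit comes first.

```julia
using LinearAlgebra

# Hypothetical helper: numerical ranks of the QTT unfoldings of 2^d samples v.
# Bond k separates the k least significant index bits from the other d - k bits.
function qtt_unfolding_ranks(v::AbstractVector{<:Real}; rtol::Real = 1e-10)
    d = round(Int, log2(length(v)))
    @assert 2^d == length(v) "length(v) must be a power of two"
    return map(1:d-1) do k
        s = svdvals(reshape(v, 2^k, :))    # singular values of the 2^k × 2^(d-k) unfolding
        count(>(rtol * s[1]), s)
    end
end

f = x -> cos(1 / (x^3 + 0.01)) + sin(π * x)   # function from the interpolation example
d = 10
x = LinRange(0, 1, 2^d)
ranks_f = qtt_unfolding_ranks(f.(x))
bounds = [min(2^k, 2^(d - k)) for k in 1:d-1]
all(ranks_f .<= bounds)                         # true by construction

# sin(πx) needs rank at most 2, since sin(a + b) = sin(a)cos(b) + cos(a)sin(b)
ranks_sin = qtt_unfolding_ranks(sin.(π .* x))
```

Computing all unfoldings densely is only feasible for moderate d; the point of constructions such as the ones above is to obtain a QTT without ever forming the full vector.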