Source: The Gradient
What is the Role of Mathematics in Modern Machine Learning?
The past decade has witnessed a shift in how progress is made in machine learning. Research involving carefully designed and mathematically principled architectures results in only marginal improvements, while compute-intensive, engineering-first efforts that scale to ever larger training sets and model parameter counts produce remarkable new capabilities unpredicted by existing theory. Mathematics and statistics, once the primary guides of machine learning research, now struggle to provide immediate insight into the latest breakthroughs. This is not the first time that empirical progress in machine learning has outpaced more theory-motivated approaches, yet the magnitude of recent advances has forced us to swallow the bitter pill of the “Bitter Lesson” yet again [1].
This shift has prompted speculation about mathematics’ diminished role in machine learning research moving forward. It is already evident that mathematics will have to share the stage with a broader range of perspectives: for instance, biology, which has deep experience drawing conclusions about irreducibly complex systems, or the social sciences, as AI is integrated ever more deeply into society. The increasingly interdisciplinary nature of machine learning should be welcomed as a positive development by all researchers.
However, we argue that mathematics remains as relevant as ever; its role is simply evolving. For example, whereas mathematics might once have primarily provided theoretical guarantees on model performance, it may soon be more commonly used for post-hoc explanations of empirical phenomena observed in model training and performance, a role analogous to the one it plays in physics. Similarly, while mathematical intuition might once have guided the design of handcrafted features or architectural details at a granular level, its use may shift to higher-level design choices, such as matching architecture to underlying task structure or data symmetries.
None of this is completely new. Mathematics has always served multiple purposes in machine learning. After all, the translation-equivariant convolutional neural network, which exemplifies the idea of matching architecture to data symmetries mentioned above, is now over 40 years old. What’s changing are the kinds of problems where mathematics will have the greatest impact and the ways it will most commonly be applied.
An intriguing consequence of the shift towards scale is that it has broadened the range of mathematical fields applicable to machine learning. “Pure” mathematical domains such as topology, algebra, and geometry are now joining the more traditionally applied fields of probability theory, analysis, and linear algebra. These pure fields have grown and developed over the last century to handle high levels of abstraction and complexity, helping mathematicians make discoveries about spaces, algebraic objects, and combinatorial processes that at first glance seem beyond human intuition. These capabilities promise to address many of the biggest challenges in modern deep learning.
In this article we will explore several areas of current research that demonstrate the enduring ability of mathematics to guide the process of discovery and understanding in machine learning.
Figure 1: Mathematics can illuminate the ways that ReLU-based neural networks shatter input space into countless polygonal regions, in each of which the model behaves like a linear map [2, 3, 4]. These decompositions create beautiful patterns. (Figure made with SplineCam [5]).
Describing an Elephant from a Pin Prick
Suppose you are given a 7-billion-parameter neural network with 50 layers and are asked to analyze it; how would you begin? The standard procedure would be to calculate relevant performance statistics: for instance, accuracy on a suite of evaluation benchmarks. In certain situations, this may be sufficient. However, deep learning models are complex and multifaceted. Two computer vision models with the same accuracy may differ substantially in their generalization to out-of-distribution data, calibration, adversarial robustness, and other “secondary statistics” that are critical in many real-world applications. Beyond this, all evidence suggests that to build a complete scientific understanding of deep learning, we will need to venture beyond evaluation scores. Indeed, just as it is impossible to capture all the dimensions of humanity with a single numerical quantity (e.g., IQ, height), trying to understand a model through one or even several statistics alone is fundamentally limiting.
One difference between understanding a human and understanding a model is that we have easy access to all model parameters and all the individual computations that occur in a model. Indeed, by extracting a model’s hidden activations we can directly trace the process by which a model converts raw input into a prediction. Unfortunately, the world of hidden activations is far less hospitable than that of simple model performance statistics. Like the initial input, hidden activations are usually high dimensional, but unlike input data they are not structured in a form that humans can understand. If we venture into even higher dimensions, we can try to understand a model through its weights directly. Here, in the space of model weights, we have the freedom to move in millions to billions of orthogonal directions from a single starting point. How do we even begin to make sense of these worlds?
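In practice, extracting hidden activations is straightforward in frameworks like PyTorch. Here is a minimal sketch, using a toy model as a stand-in for a real network (the architecture and layer names are invented for illustration), that captures each layer’s output during a forward pass via forward hooks:

```python
import torch
import torch.nn as nn

# A toy model standing in for a large network (hypothetical architecture).
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))

activations = {}

def save_activation(name):
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook

# Attach a forward hook to every layer to record its hidden activations.
for idx, layer in enumerate(model):
    layer.register_forward_hook(save_activation(f"layer_{idx}"))

x = torch.randn(8, 64)   # a batch of 8 random inputs
model(x)                 # one forward pass populates `activations`
for name, act in activations.items():
    print(name, tuple(act.shape))
```

Even for this tiny model, the recorded activations are 128-dimensional vectors with no obvious human-readable structure, which is exactly the difficulty described above.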
There is a well-known fable in which three blind men each feel a different part of an elephant. The description each gives of the animal is completely different, reflecting only the body part that man happened to touch. We argue that, unlike the blind men, who can at least run a hand over a substantial portion of one of the elephant’s body parts, current methods of analyzing the hidden activations and weights of a model are akin to trying to describe the elephant from the touch of a single pin.
Tools to Characterize What We Cannot Visualize
Despite the popular perception that mathematicians exclusively focus on solving problems, much of research mathematics involves understanding the right questions to ask in the first place. This is natural, since many of the objects that mathematicians study are so far removed from everyday experience that we start with very limited intuition for what we can hope to actually understand. Substantial effort is often required to build up tools that let us leverage our existing intuition and achieve tractable results that increase our understanding. The concept of a rotation provides a nice example: rotations are very familiar in 2 and 3 dimensions but become less and less accessible to everyday intuition as the dimension grows. It is in this regime that the differing viewpoints provided by pure mathematics become increasingly important for building a holistic picture of what rotations actually are.
Those who know a little linear algebra will remember that rotations generalize to higher dimensions and that in $n$ dimensions they can be realized by $n \times n$ orthogonal matrices with determinant $1$. The set of these matrices is commonly written $SO(n)$ and called the special orthogonal group. Suppose we want to understand the set of all $n$-dimensional rotations. There are many complementary approaches: we can explore the linear algebraic structure of the matrices in $SO(n)$, or study $SO(n)$ based on how each element behaves as an operator acting on $\mathbb{R}^n$.
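To make the definition concrete, here is a minimal numpy sketch (the example matrices and tolerance are illustrative choices, not from the article) that tests whether a matrix belongs to $SO(n)$ by checking the two defining conditions, orthogonality and unit determinant:

```python
import numpy as np

def is_special_orthogonal(m, tol=1e-8):
    """Check membership in SO(n): m @ m.T = I and det(m) = +1."""
    n = m.shape[0]
    orthogonal = np.allclose(m @ m.T, np.eye(n), atol=tol)
    return orthogonal and np.isclose(np.linalg.det(m), 1.0, atol=tol)

# A 3D rotation about the z-axis lies in SO(3) ...
theta = 0.7
rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0,            0.0,           1.0]])
print(is_special_orthogonal(rz))          # True

# ... while a reflection is orthogonal but has determinant -1, so it is not.
reflection = np.diag([1.0, 1.0, -1.0])
print(is_special_orthogonal(reflection))  # False
```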
Alternatively, we can try to use our innate spatial intuition to understand $SO(n)$, which turns out to be a powerful perspective in math. In any dimension $n$, $SO(n)$ is a geometric object called a manifold: very roughly, a space that locally looks like Euclidean space but which may have twists, holes, and other non-Euclidean features when we zoom out. Indeed, whether we make it precise or not, we all have a sense of whether two rotations are “close” to each other. For example, the reader would probably agree that $2$-dimensional rotations of $90^\circ$ and $91^\circ$ “feel” closer than rotations of $90^\circ$ and $180^\circ$. When $n=2$, one can show that the set of all rotations is geometrically “equivalent” to a $1$-dimensional circle. So, much of what we know about the circle can be translated to $SO(2)$.
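This intuition about closeness can be checked numerically. The sketch below, a hypothetical illustration using the Frobenius norm as one convenient notion of distance between matrices, confirms that rotations of $90^\circ$ and $91^\circ$ sit much closer together in $SO(2)$ than rotations of $90^\circ$ and $180^\circ$:

```python
import numpy as np

def rot2d(theta):
    """The 2D rotation matrix for angle theta, an element of SO(2)."""
    return np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])

a = rot2d(np.pi / 2)                 # 90 degrees
b = rot2d(np.pi / 2 + np.pi / 180)   # 91 degrees
c = rot2d(np.pi)                     # 180 degrees

print(np.linalg.norm(a - b))  # small: 90° and 91° are nearby on the circle
print(np.linalg.norm(a - c))  # much larger: 90° and 180° are far apart
```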
What happens when we want to study the geometry of rotations in $n$ dimensions for $n > 3$? If $n = 512$ (the dimension of a typical latent space, for instance), this amounts to studying a manifold sitting in $512^2$-dimensional space. Our visual intuition seemingly fails here, since it is not clear how concepts that are familiar in 2 and 3 dimensions can be utilized in $512^2$ dimensions. Mathematicians have been confronting the problem of understanding the un-visualizable for hundreds of years. One strategy is to find generalizations of familiar spatial concepts from $2$ and $3$ dimensions to $n$ dimensions that connect with our intuition.
This approach is already being used to better understand and characterize experimental observations about the space of model weights, hidden activations, and input data of deep learning models. We provide a taste of such tools and applications here:
- Intrinsic Dimension: Dimension is a concept familiar not only from our experience with the spatial dimensions we can readily access, namely 1, 2, and 3 dimensions, but also from more informal notions of “degrees of freedom” in everyday systems, such as driving a car (moving forward or back, turning the steering wheel left or right). The notion of dimension arises naturally in machine learning, where we may want to capture the number of independent ways in which a dataset, learned representation, or collection of weight matrices actually varies.
In formal mathematics, the definitions of dimension depend on the kind of space one is studying but they all capture some aspect of this everyday intuition. As a simple example, if I walk along the perimeter of a circle, I am only able to move forward and backward, and thus the dimension of this space is $1$. For spaces like the circle which are manifolds, dimension can be formally defined by the fact that a sufficiently small neighborhood around each point looks like a subset of some Euclidean space $\mathbb{R}^k$. We then say that the manifold is $k$-dimensional. If we zoom in on a small segment of the circle, it almost looks like a segment of $\mathbb{R} = \mathbb{R}^1$, and hence the circle is $1$-dimensional.
The manifold hypothesis posits that many types of data (at least approximately) live on a low-dimensional manifold even though they are embedded in a high-dimensional space. If we assume this is true, it makes sense that the dimension of this underlying manifold, called the intrinsic dimension of the data, is one way to describe the complexity of the dataset. Researchers have estimated intrinsic dimension for common benchmark datasets, showing that intrinsic dimension appears to be correlated with the ease with which models generalize from training to test sets [6], and can explain differences in model performance and robustness across domains such as medical images [7]. Intrinsic dimension is also a fundamental ingredient in some proposed explanations of data scaling laws [8, 9], which underlie the race to build ever bigger generative models. A sketch of one simple estimator appears below.
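To give a flavor of how such estimates are computed, here is a minimal sketch of one popular nearest-neighbor estimator, TwoNN (Facco et al., 2017); we choose it for illustration, and it is not necessarily the estimator used in [6, 7]. It exploits the fact that, on a $d$-dimensional manifold, the ratio of each point’s second-nearest-neighbor distance to its first approximately follows a Pareto law with exponent $d$:

```python
import numpy as np
from scipy.spatial import cKDTree

def twonn_intrinsic_dimension(points):
    """TwoNN estimator: for each point take the distances r1, r2 to its
    two nearest neighbors; the maximum-likelihood estimate of the
    intrinsic dimension d is N / sum(log(r2 / r1))."""
    tree = cKDTree(points)
    # k=3 because each point's nearest neighbor is itself, at distance 0
    dists, _ = tree.query(points, k=3)
    r1, r2 = dists[:, 1], dists[:, 2]
    return len(points) / np.sum(np.log(r2 / r1))

# Sanity check: a circle (a 1-dimensional manifold) embedded in R^512
rng = np.random.default_rng(0)
theta = rng.uniform(0, 2 * np.pi, size=5000)
data = np.zeros((5000, 512))
data[:, 0], data[:, 1] = np.cos(theta), np.sin(theta)
print(twonn_intrinsic_dimension(data))  # close to 1, despite 512 ambient dims
```

The estimator recovers a dimension near 1 even though the data lives in a 512-dimensional ambient space, which is exactly the distinction between intrinsic and ambient dimension that the manifold hypothesis draws.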