If you rotate an image of a molecular structure, a human can tell the rotated image is still the same molecule, but a machine-learning model might think it is a new data point. In computer science parlance, the molecule is “symmetric,” meaning the fundamental structure of that molecule remains the same if it undergoes certain transformations, like rotation.
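As a minimal sketch of this idea (using a hypothetical three-atom "molecule" in two dimensions, not data from the study), rotating the coordinates changes the raw numbers a model sees, even though rotation-invariant quantities such as interatomic distances do not:

```python
import numpy as np

def rotate(points: np.ndarray, theta: float) -> np.ndarray:
    """Rotate 2D points counterclockwise by theta radians."""
    c, s = np.cos(theta), np.sin(theta)
    return points @ np.array([[c, -s], [s, c]]).T

molecule = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 0.9]])  # three "atoms"
rotated = rotate(molecule, np.pi / 3)

# The raw coordinates differ, so a naive model treats this as a new data point...
print(np.allclose(molecule, rotated))  # False

# ...but a rotation-invariant description (pairwise distances) is unchanged.
def pairwise_distances(p: np.ndarray) -> np.ndarray:
    return np.linalg.norm(p[:, None, :] - p[None, :, :], axis=-1)

print(np.allclose(pairwise_distances(molecule), pairwise_distances(rotated)))  # True
```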
If a drug discovery model doesn’t understand symmetry, it could make inaccurate predictions about molecular properties. But despite some empirical successes, it’s been unclear whether there is a computationally efficient method to train a good model that is guaranteed to respect symmetry.
A new study by MIT researchers answers this question, demonstrating the first method for machine learning with symmetry that is provably efficient in terms of both the computation and the data it requires.
These results clarify a foundational question, and they could aid researchers in the development of more powerful machine-learning models that are designed to handle symmetry. Such models would be useful in a variety of applications, from discovering new materials to identifying astronomical anomalies to unraveling complex climate patterns.
“These symmetries are important because they are some sort of information that nature is telling us about the data, and we should take it into account in our machine-learning models. We’ve now shown that it is possible to do machine-learning with symmetric data in an efficient way,” says Behrooz Tahmasebi, an MIT graduate student and co-lead author of this study.
He is joined on the paper by co-lead author and MIT graduate student Ashkan Soleymani; Stefanie Jegelka, an associate professor of electrical engineering and computer science (EECS) and a member of the Institute for Data, Systems, and Society (IDSS) and the Computer Science and Artificial Intelligence Laboratory (CSAIL); and senior author Patrick Jaillet, the Dugald C. Jackson Professor of Electrical Engineering and Computer Science and a principal investigator in the Laboratory for Information and Decision Systems (LIDS). The research was recently presented at the International Conference on Machine Learning.
Studying symmetry
Symmetric data appear in many domains, especially the natural sciences and physics. A model that recognizes symmetries is able to identify an object, like a car, no matter where that object is placed in an image, for example.
Unless a machine-learning model is designed to handle symmetry, it could be less accurate and prone to failure when faced with new symmetric data in real-world situations. On the flip side, models that take advantage of symmetry could be faster and require fewer data for training.
But training a model to process symmetric data is no easy task.
One common approach is called data augmentation, where researchers apply symmetry transformations to each training point to generate many additional examples, helping the model generalize better to new data. For instance, one could rotate a molecular structure many times to produce new training data. But if researchers want the model to be guaranteed to respect symmetry, they would need to cover every possible transformation, which can be computationally prohibitive.
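A minimal sketch of this idea, assuming a 2D point-cloud input and random planar rotations (both illustrative choices, not the paper's setup):

```python
import numpy as np

rng = np.random.default_rng(0)

def augment_with_rotations(points: np.ndarray, k: int) -> list[np.ndarray]:
    """Return k randomly rotated copies of a 2D point cloud; each reuses the original label."""
    copies = []
    for theta in rng.uniform(0.0, 2.0 * np.pi, size=k):
        c, s = np.cos(theta), np.sin(theta)
        copies.append(points @ np.array([[c, -s], [s, c]]).T)
    return copies

molecule = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 0.9]])
views = augment_with_rotations(molecule, k=8)

# The training set (and the training cost) grows 8x, yet the model is still
# only approximately invariant: guaranteeing exact symmetry would mean
# covering every rotation, and there are infinitely many of them.
print(len(views))  # 8
```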
An alternative approach is to encode symmetry into the model’s architecture. A well-known example of this is a graph neural network (GNN), which inherently handles symmetric data because of how it is designed.
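As a hedged sketch of that design principle (a hypothetical toy message-passing layer, not any particular GNN library), aggregating neighbor features by summation and then sum-pooling the result makes the output provably indifferent to how the nodes are numbered:

```python
import numpy as np

def gnn_readout(adj: np.ndarray, feats: np.ndarray, w: np.ndarray) -> np.ndarray:
    """One message-passing layer plus sum pooling (a toy model for illustration)."""
    messages = adj @ feats          # sum over neighbors: node order is irrelevant
    hidden = np.tanh(messages @ w)  # the same weights are shared across all nodes
    return hidden.sum(axis=0)       # sum pooling: a permutation-invariant output

rng = np.random.default_rng(1)
adj = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)  # a 3-node graph
feats = rng.normal(size=(3, 4))  # one feature vector per node
w = rng.normal(size=(4, 4))

perm = np.array([2, 0, 1])  # relabel the nodes
out_original = gnn_readout(adj, feats, w)
out_relabeled = gnn_readout(adj[perm][:, perm], feats[perm], w)
print(np.allclose(out_original, out_relabeled))  # True: same graph, same output
```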
“Graph neural networks are fast and efficient, and they take care of symmetry quite well, but nobody really knows what these models are learning or why they work. Understanding GNNs is a main motivation of our work, so we started with a theoretical evaluation of what happens when data are symmetric,” Tahmasebi says.
They explored the statistical-computational tradeoff in machine learning with symmetric data: methods that need less training data tend to demand more computation, and vice versa, so researchers must find the right balance.
Building on this theoretical evaluation, the researchers designed an efficient algorithm for machine learning with symmetric data.
Mathematical combinations
To do this, they borrowed ideas from algebra to shrink and simplify the problem. Then, they reformulated the problem using ideas from geometry that effectively capture symmetry.
Finally, they combined the algebra and the geometry into an optimization problem that can be solved efficiently, resulting in their new algorithm.
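The paper's actual construction is specific to its algebraic and geometric framing, but as a loose stand-in, here is one classical way symmetry can be folded into an efficiently solvable (convex) optimization: kernel ridge regression with a group-averaged kernel. Every choice below (the Gaussian kernel, the 12-angle discretization of the rotation group) is an illustrative assumption, not the authors' method:

```python
import numpy as np

# A coarse discretization of the 2D rotation group (12 evenly spaced angles).
ROTATIONS = [np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]])
             for t in np.linspace(0.0, 2.0 * np.pi, 12, endpoint=False)]

def invariant_kernel(x: np.ndarray, y: np.ndarray, gamma: float = 1.0) -> float:
    """Gaussian kernel averaged over rotations of one argument.

    Averaging a positive-definite kernel over a group keeps it positive
    definite, and the resulting predictor is (approximately) invariant.
    """
    return float(np.mean([np.exp(-gamma * np.sum((x - g @ y) ** 2))
                          for g in ROTATIONS]))

def fit_kernel_ridge(X: np.ndarray, y: np.ndarray, lam: float = 1e-3) -> np.ndarray:
    """Solve (K + lam*I) alpha = y, a convex problem with a closed-form solution."""
    K = np.array([[invariant_kernel(a, b) for b in X] for a in X])
    return np.linalg.solve(K + lam * np.eye(len(X)), y)

# Toy usage: four 2D inputs with scalar labels.
X = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.5, 0.5]])
y = np.array([1.0, 1.0, 1.0, 0.0])
alpha = fit_kernel_ridge(X, y)
predict = lambda x_new: sum(a * invariant_kernel(x_new, xi) for a, xi in zip(alpha, X))
print(predict(np.array([0.0, -1.0])))  # approximately matches its rotated versions
```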
“Most of the theory and applications were focusing on either algebra or geometry. Here we just combined them,” Tahmasebi says.
The algorithm requires fewer data samples for training than classical approaches, which could improve a model's accuracy and its ability to adapt to new applications.
By proving that efficient algorithms for machine learning with symmetry are possible, and by demonstrating how to build one, the researchers have taken a step toward new neural network architectures that could be more accurate and less resource-intensive than current models.
Scientists could also use this analysis as a starting point to examine the inner workings of GNNs, and how their operations differ from the algorithm the MIT researchers developed.
“Once we know that better, we can design more interpretable, more robust, and more efficient neural network architectures,” adds Soleymani.
This research is funded, in part, by the National Research Foundation of Singapore, DSO National Laboratories of Singapore, the U.S. Office of Naval Research, the U.S. National Science Foundation, and an Alexander von Humboldt Professorship.