Over the past several years, deep learning techniques have produced many success stories in machine learning. While the practical success of deep learning has been phenomenal, formal guarantees have been lacking. Our current theoretical understanding of many of the techniques at the center of the ongoing big-data revolution remains, at best, far from sufficient for rigorous analysis. In this episode of Data Skeptic, host Kyle Polich welcomes guest John Wilmes, a mathematics postdoctoral researcher at Georgia Tech, to discuss the efficiency of neural network learning through the lens of complexity theory.
Deep neural networks have been successfully applied to many problems in machine learning, including image classification, speech recognition, natural language processing, and autonomous driving. However, it remains puzzling why and when efficient algorithms, such as stochastic gradient methods, yield solutions that perform well. Hence, it is worth asking: what can neural networks learn efficiently in theory? What does it mean to “solve” a problem? What are the provable guarantees for deep learning algorithms? Under what conditions (on the input distribution and the target function) does stochastic gradient descent work? Does it help if the data is generated by a neural network? In other words, is the “realizable” case easier?
John and his colleagues, Le Song and Santosh Vempala at Georgia Tech and Bo Xie at Facebook, tackled these questions in their paper 'On the Complexity of Learning Neural Networks.' If we are going to rely heavily on a model, we should learn as much as we can about its limits and capabilities. When relying on a model, it is natural to ask whether it can solve a given problem at all, and whether it can solve it efficiently. Complexity theory recognizes that even when a problem is decidable, and thus computationally solvable in principle, it may not be solvable in practice if solving it requires an inordinate amount of time.
Suppose you want to use a machine learning algorithm to learn some concept, say to approximate some function, and all you know about that function is a set of labelled examples, that is, input/output pairs. How would you figure out the weights? How would you train the network? John's results suggest that, in general, there is no computationally feasible way to do this with good running-time guarantees. He found that no training algorithm is likely to be efficient when learning involves some computationally intractable optimization problem. The same applies when the hypothesis class used for learning may be more complex than the target class. John refers to an earlier study by Blum and Rivest (1992), which showed that even training a three-node network is NP-complete. As Blum and Rivest put it, "unless P=NP, for any polynomial-time training algorithm there will be some sets of training data on which the algorithm fails to correctly train the network, even though there exist edge weights so the network could correctly classify the data."
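To make the Blum and Rivest setting a little more concrete, here is a minimal Python sketch of a three-node network (two hidden linear threshold units feeding one output threshold unit) "trained" by exhaustive search. The restriction of weights to {-1, 0, 1} and biases to {-2, ..., 2}, and the helper names `threshold`, `three_node_net`, and `brute_force_train`, are hypothetical choices made for illustration only. Searching a fixed grid is not a general training algorithm, since real-valued weights may be needed, but it shows the shape of the problem: checking a candidate against the data is cheap, while the number of candidates, (5·3^n)^2·45 here, grows exponentially with the input dimension n.

```python
from itertools import product

# Hypothetical small integer grids, chosen only to keep the search finite.
WEIGHTS = (-1, 0, 1)
BIASES = (-2, -1, 0, 1, 2)


def threshold(weights, bias, x):
    """Linear threshold unit: outputs 1 iff w.x + b > 0, else 0."""
    return 1 if sum(w * xi for w, xi in zip(weights, x)) + bias > 0 else 0


def three_node_net(params, x):
    """Two hidden threshold units feeding one output threshold unit."""
    (w1, b1), (w2, b2), (v, c) = params
    hidden = (threshold(w1, b1, x), threshold(w2, b2, x))
    return threshold(v, c, hidden)


def brute_force_train(data):
    """Return the first grid parameters consistent with every (x, y) pair, or None.

    The search space has (5 * 3**n)**2 * 45 points, exponential in the input size n.
    """
    n = len(data[0][0])
    hidden_units = [(w, b) for w in product(WEIGHTS, repeat=n) for b in BIASES]
    output_units = [(v, c) for v in product(WEIGHTS, repeat=2) for c in BIASES]
    for params in product(hidden_units, hidden_units, output_units):
        if all(three_node_net(params, x) == y for x, y in data):
            return params
    return None


# Usage: XOR is not linearly separable, but a three-node network can represent it.
xor_data = [((0, 0), 0), ((0, 1), 1), ((1, 0), 1), ((1, 1), 0)]
params = brute_force_train(xor_data)
print(params)
```

For two inputs the grid has 91,125 candidates, which is trivial; every extra input multiplies the count by 9, which is why enumeration stops being an option long before realistic input sizes.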
John and his colleagues proved a comprehensive lower bound ruling out the possibility that data generated by neural networks with a single hidden layer, smooth activation functions, and benign input distributions can be learned efficiently. According to John, the lower bound has two parts. First, neural network updates can be viewed as statistical queries to the input distribution. In supervised learning, statistical query algorithms are algorithms that never interact directly with individual labelled examples; instead, they ask for the expected value of some function of a labelled example and receive an answer that is accurate only to within a given tolerance. Second, there are many very different one-hidden-layer networks, and in order to learn the correct one, any algorithm that makes only statistical queries whose tolerance is not too small has to make an exponential number of queries. Essentially, John's results suggest that one needs an exponential number of queries to obtain a reasonable approximation; hence, it would take an impractical amount of time for the algorithm to learn the function it is meant to learn.
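The first part of the argument, that gradient-based updates can be phrased as statistical queries, can be sketched in a few lines of Python. This is a simplified illustration under our own assumptions, not the construction from the paper: the learner is a single sigmoid unit rather than a one-hidden-layer network, the loss is squared loss, inputs are uniform on [-1, 1]^3, the target weights `W_STAR` are hypothetical, and the oracle is simulated by an empirical mean plus bounded random noise of size at most `tau` (a true SQ oracle may answer adversarially within its tolerance). Each coordinate of the gradient is one query of the form E[q(x, y)], so the learner never sees an individual labelled example.

```python
import math
import random

W_STAR = (2.0, -1.0, 0.5)  # hypothetical target weights for the illustration


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sample_target(rng):
    """Draw x uniformly from [-1, 1]^d and label it with the target sigmoid unit."""
    x = [rng.uniform(-1.0, 1.0) for _ in W_STAR]
    return x, sigmoid(sum(a * b for a, b in zip(W_STAR, x)))


def make_sq_oracle(sample, tau, num_samples, rng):
    """Simulated SQ oracle: estimates E[q(x, y)] by an empirical mean,
    then perturbs it by uniform noise in [-tau, tau]."""
    def oracle(q):
        draws = [sample(rng) for _ in range(num_samples)]
        mean = sum(q(x, y) for x, y in draws) / num_samples
        return mean + rng.uniform(-tau, tau)
    return oracle


def sq_gradient_descent(oracle, dim, steps, lr):
    """Fit h(x) = sigmoid(w.x) under squared loss using only statistical queries.

    Coordinate i of the gradient is E[(h(x) - y) * h(x) * (1 - h(x)) * x_i],
    which the learner obtains as a single query to the oracle.
    """
    w = [0.0] * dim
    for _ in range(steps):
        grad = []
        for i in range(dim):
            def q(x, y, i=i, w=tuple(w)):
                h = sigmoid(sum(wj * xj for wj, xj in zip(w, x)))
                return (h - y) * h * (1.0 - h) * x[i]
            grad.append(oracle(q))
        w = [wi - lr * gi for wi, gi in zip(w, grad)]
    return w


# Usage: learn the target using 3 queries per step, each with tolerance 1e-3.
rng = random.Random(0)
oracle = make_sq_oracle(sample_target, tau=1e-3, num_samples=400, rng=rng)
w = sq_gradient_descent(oracle, dim=3, steps=150, lr=20.0)
print([round(wi, 3) for wi in w])
```

In this easy, low-dimensional case a few hundred queries suffice. The lower bound concerns the second part of the argument: for the hard family of one-hidden-layer networks, no choice of queries of this kind can succeed without exponentially many of them.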