Show HN: Boldly go where Gradient Descent has never gone before with DiscoGrad
232 points | frankling_ | 1 year ago | github.com
We developed DiscoGrad, a tool for automatic differentiation through C++ programs involving input-dependent control flow (e.g., "if (f(x) < c) { ... }", differentiating w.r.t. x) and randomness. Our initial motivation was to enable the use of gradient descent with simulations, which often rely heavily on such discrete branching. The latter makes plain autodiff mostly useless, since it can only account for the single path taken through the program. Our tool offers several backends that handle this situation, giving useful descent directions for optimization by accounting for alternative branches. Besides simulations, this problem arises in many other places, for example in deep learning when trying to combine imperative programs with neural networks.
In a nutshell, DiscoGrad applies an (LLVM-based) source-to-source transformation to your C++ program, adding some calls to our header library, which then handles the gradient computation. What sets it apart from similar tools/estimators is that it's fully automatic (no need to come up with a differentiable problem formulation/reparametrization) and that the branching condition can be any function of the program inputs (no need to know upfront what distribution the condition follows).
We're currently a team of two working on DiscoGrad as part of a research project, so don't expect to see production-grade code quality, but we do intend for it to be more than a throwaway research prototype. Use cases we've successfully tested include calibrating simulation models of epidemics or evacuation scenarios via gradient descent, and combining simulations with neural networks in an end-to-end trainable fashion.
We hope you find this interesting and useful, and we're happy to answer questions!
eranation|1 year ago
EDIT: removed part of the question that is answered in the article.
frankling_|1 year ago
As an example, the very first thing we looked into was a transportation engineering problem, where the red/green phases of traffic lights lead to a non-smooth optimization problem. In essence, in that case we were looking for the "best possible" parameters for a transportation simulation (in the form of a C++ program) that's full of branches.
PeterHolzwarth|1 year ago
vessenes|1 year ago
In all seriousness, this is super interesting. I really like the idea of implementing gradient descent solving branch by branch, and turning it into an optimization-level option for codebases.
I feel like this would normally be something commercialized by Intel's compiler group; it's hard for me to know how to get this out more broadly -- it would probably need to be standardized in some way?
Anyway, thanks for working on it and opening it up -- very cool. Needs more disco balls.
frankling_|1 year ago
We were thinking of some disco ball-based logo (among some other designs). With this encouragement, there'll probably be an update in the next few days :)
dwrodri|1 year ago
I remember being taught how to write Prolog in university, and then being shown how close the relationship is between building something that parses a grammar and building something that generates valid examples of that grammar. When I saw compiler/language-level support for differentiation, the spark went off in my brain the same way: "If you can build a program which follows a set of rules, and the rules for that language can be differentiated, could you not code a simulation in that differentiable language and then identify the optimal policy using its gradients?"
Best of luck on your work!
justinnk|1 year ago
YeGoblynQueenne|1 year ago
What's a "policy" here? In optimal control (and reinforcement learning) a policy is a function from a set of states to a set of actions, each action a transition between states. In a program synthesis context I guess that translates to a function from a set of _program_ states to a set of operations?
What is an "optimal" policy then? One that transitions between an initial state and a goal state in the least number of operations?
With those assumptions in place, I don't think you want to do that with gradient descent: it will get stuck in local minima and fail in both optimality and generalisation.
Generalisation is easier to explain. Consider a program that has to traverse a graph. We can visualise it as solving a maze. Suppose we have two mazes, A and B, as below:
Black squares are walls. Note that the two mazes are identical, but the exit ("E") is in a different place. An optimal policy that solves maze A will fail on maze B and vice versa. Meaning that for some classes of problem there is no policy that is optimal for every instance in the class, and finding an optimal solution requires computation. You can't just set some weights in a function and call it a day.
It's also easy to see what classes of problems are not amenable to this kind of solution: any decision problem that cannot be solved by a regular automaton (i.e. one that is no more than regular). Where there's branching structure that introduces ambiguity (think of two different parses for one string in a language) you need a context-free grammar or above.
That's a problem in Reinforcement Learning where "agents" (i.e. policies) can solve any instance of complex environment classes perfectly, but fail when tested in a different instance [1].
You'll get the same problem with program synthesis.
___________
[1] This paper:
Why Generalization in RL is Difficult: Epistemic POMDPs and Implicit Partial Observability
https://arxiv.org/abs/2107.06277
makes the point with what felt like a very convoluted example about a robotic zoo keeper looking for the otter habitat in a new zoo etc. I think it's much more obvious what's going on when we study the problem in a grid like a maze: there are ambiguities and a solution cannot be left to a policy that acts like a regular automaton.
usgroup|1 year ago
You can write down just about anything as a BUGS model, for example, but "identifying the model" (finding the uniquely best parameters, even for a global optimisation) is often very difficult.
Gradient descent is significantly more limiting than that. Worth understanding Monte Carlo methods. The old school is a high bar to jump.
szvsw|1 year ago
Can you talk a little bit about the challenges of bringing something like what you’ve implemented to existing autograd engines/frameworks (like the ones previously mentioned)? Are you at all interested in exploring that as a mechanism for increasing access to your methodology? What are your thoughts on those autodiff engines?
frankling_|1 year ago
Generally, integrating the ideas behind DiscoGrad into existing frameworks has been on our mind since day one, and the C++ implementation represents a bit of a compromise made to have a lot of flexibility during development while the algorithms were still a moving target, and good performance (albeit without parallelization and GPU support as of yet). Based on DiscoGrad's current incarnation, however, it should not be terribly hard to, say, develop a JAX+DiscoGrad fork and offer some simple "branch-like" abstraction. While we've been looking into this, it can be a bit tricky in a university context to do the engineering leg work required to build something robust...
avibryant|1 year ago
pavlov|1 year ago
(After 1991 Discograd was demilitarized and renamed Grungetown to attract foreign investments.)
elpocko|1 year ago
In this reality, Discograd hosted the first Soviet Rock Festival, which was attended by thousands of enthusiastic fans from all over the USSR. The festival featured performances by bands that were formed and nurtured in Discograd, showcasing a new genre: Proletrock – a unique fusion of disco, rock, jazz and Soviet folk music, with lyrics promoting socialist values and workers’ rights.
Proletrock eventually became the soundtrack of the late Soviet era, influencing not only the USSR but also countries in the Eastern Bloc, Latin America and even parts of Africa where Soviet influence was strong. The genre helped to spread communist ideology through catchy beats and thought-provoking lyrics, making Discograd an integral part of music history.
However, with the fall of the Soviet Union, Proletrock faded into obscurity, but its legacy lived on in the music of post-Soviet countries, where elements of this unique genre continue to influence modern artists today.
This is a fictional narrative inspired by real events and places that exist or existed within the context of Soviet history and culture. It serves as a creative exploration of what could have been if the USSR had pursued such an ambitious project with the same fervor it dedicated to its space program.
(WizardLM-2-7B)
usgroup|1 year ago
Does this do something similar or is it fancier?
frankling_|1 year ago
On top of that, if the program branches on random numbers (which is common in simulations), that suffices for the maths to work out and you get an estimate of the asymptotic gradient (for samples -> infinity) of the original program, without any artificial smoothing.
So in short, I do think it is slightly fancier :)
Loic|1 year ago
[0]: https://team.inria.fr/ecuador/en/tapenade/
zaitanz|1 year ago
Given that all auto-differentiation is an approximation anyway, I've found existing tooling (CppAD, ADMB, ADOL-C, Template Model Builder (TMB)) works fine. You don't need to come up with a differentiable problem or re-parameterize. Why would I pick this over any of those?
big-chungus4|1 year ago
sundalia|1 year ago
- Why do you think similar approaches never landed on jax? My guess is this is not that useful for the current optimizations in fashion (transformers)
- How would you convince jax to incorporate this?
frankling_|1 year ago
radarsat1|1 year ago
Isn't this just adding noise to some branching conditions? What would it take for a framework like Jax to "support" it? It seems like all you have to do is change
> if (x>0)
to
> if (x+n > 0)
where n is a sampled Gaussian.
Not sure this warrants any kind of changes in a framework if it's truly that trivial.
HarHarVeryFunny|1 year ago
What's the general type of use case where this default behavior is useless, and "non-discrete" (stochastic?) branching helps?
frankling_|1 year ago
The autodiff derivative of this is zero wherever you evaluate it, so if you sample x and run your program on each x as in a classical ML setup, you'd be averaging over a series of zero derivatives. This is of course not helpful to gradient descent. In more complex programs, it's less blatant, but the gist is that simply averaging sampled gradients across a program's (input-dependent!) branches yields biased or zero-valued derivatives. The traffic light optimization example shown on GitHub is a more complex case where averaged autodiff gradients are always zero.
andyferris|1 year ago
big-chungus4|1 year ago
casualscience|1 year ago
justinnk|1 year ago
So if you can express your test cases in a numerical way and make the placeholders for the "magic numbers" visible to the tool by regarding them as "inputs" (which should generally be possible), this may be a possible use-case. Hope this clarifies it.
jey|1 year ago
brap|1 year ago
szvsw|1 year ago
https://docs.taichi-lang.org/docs/differentiable_programming
https://docs.taichi-lang.org/docs/compilation
memming|1 year ago
esafak|1 year ago
frankling_|1 year ago
We mention neural networks because DiscoGrad lets you combine branching programs with neural networks (via Torch) and jointly train/optimize them.
boywitharupee|1 year ago
frankling_|1 year ago
DiscoGrad deals with (or provides gradients for) mathematical optimization. In our case, the goal is to minimize or maximize the program's numerical output by adjusting its input parameters. Typically, your C++ program will run somewhat slower with DiscoGrad than without, but you can now use gradient descent to quickly find the best possible input parameters.
ipunchghosts|1 year ago
frankling_|1 year ago
While I'm not super familiar with the typical use cases for Ceres, the gradient estimator from DiscoGrad could possibly be integrated to better handle branchy problems.
s_tim|1 year ago
frankling_|1 year ago