Cosmos is a framework for object-centric world modeling that is designed for compositional generalization (CG), i.e., high performance on unseen input scenes obtained through the composition of known visual "atoms."
The central insight behind Cosmos is the use of a novel form of neurosymbolic grounding. Specifically, the framework introduces two new tools: (i) neurosymbolic scene encodings, which represent each entity in a scene using a real vector computed by a neural encoder, as well as a vector of composable symbols describing attributes of the entity, and (ii) a neurosymbolic attention mechanism that binds these entities to learned rules of interaction.
Cosmos is end-to-end differentiable; also, unlike traditional neurosymbolic methods that require representations to be manually mapped to symbols, it computes an entity's symbolic attributes using vision-language foundation models. Through an evaluation that considers two different forms of CG on an established blocks-pushing domain, we show that the framework establishes a new state-of-the-art for CG in world modeling.
Our dataset contains 2D objects of various shapes and colors. An image is generated by sampling a fixed number of these objects. We refer to each object as an atom. Atoms can be composed together in many ways; we study two types of composition in this work.
Entity composition involves sampling atoms (entities) such that the specific combination of atoms is not seen during training.
Specifically, here, the model has seen a red square, a purple triangle, and a blue circle in other contexts, but never together.
The rendered image randomly places the atoms in the scene.
In this environment, each atom can be moved either North, East, South, or West. Actions are sampled uniformly.
The next state is derived by applying the action to the current state. Each atom also has a pre-specified "weight" which influences whether it can push other shapes (if heavier) or be pushed (if lighter). The weights can be inferred from the shape. The white arrow provides emphasis and isn't part of the state.
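To make the environment concrete, here is a minimal sketch of how a single transition could be generated. Everything below (the shape and color vocabularies, the weight table, the grid size, and the function names) is a hypothetical stand-in rather than the actual data generator, and the pushed atom's own collisions are ignored for brevity.

```python
import random

SHAPES = ["square", "circle", "triangle"]            # hypothetical atom vocabulary
COLORS = ["red", "green", "blue", "purple"]
WEIGHT = {"square": 3, "circle": 2, "triangle": 1}    # assumed: weight is determined by shape
MOVES = {"N": (0, 1), "E": (1, 0), "S": (0, -1), "W": (-1, 0)}

def sample_scene(n_atoms=3, grid=5):
    """Sample n_atoms (shape, color) atoms and place them at distinct grid cells."""
    cells = random.sample([(x, y) for x in range(grid) for y in range(grid)], n_atoms)
    return [{"shape": random.choice(SHAPES), "color": random.choice(COLORS), "pos": c}
            for c in cells]

def step(scene, actor_idx, action, grid=5):
    """Apply a cardinal action to one atom; a heavier atom pushes a lighter one it runs into."""
    dx, dy = MOVES[action]
    next_scene = [dict(a) for a in scene]
    actor = next_scene[actor_idx]
    tx, ty = actor["pos"][0] + dx, actor["pos"][1] + dy
    if not (0 <= tx < grid and 0 <= ty < grid):
        return next_scene                              # blocked by the wall
    for other in next_scene:
        if other is not actor and other["pos"] == (tx, ty):
            if WEIGHT[actor["shape"]] > WEIGHT[other["shape"]]:
                other["pos"] = (tx + dx, ty + dy)      # lighter atom gets pushed
            else:
                return next_scene                      # too heavy to push
    actor["pos"] = (tx, ty)
    return next_scene

scene = sample_scene()
action = random.choice(list(MOVES))                    # actions are sampled uniformly
next_scene = step(scene, actor_idx=0, action=action)
```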
Some Observations:
Entity composition (EC) has been the traditional way of studying compositional generalization. However, in EC, the movement of each object is independent of the movement of the other objects. In relational composition, we compose at the level of relations between objects: in addition to sampling new objects, objects that share a certain attribute also share dynamics. In this case, the attribute is color, so a composition occurs whenever two objects share the same color, and those objects move together.
Specifically, here, the square and the circle are both red, while the triangle is green. Like before, the model has seen each shape in other contexts, but never together; i.e., a red square and a red circle have never been seen together.
There are many ways in which we can select the relation to compose on. In this work, we study two cases.
An action is applied on a single object and the dynamics are shared with all objects of the same color that are adjacent to the object.
Consequently, the action on the square moves both the square and the circle northwards.
An action is applied on a single object and the dynamics are shared with all objects of the same color.
This time, the action on the square also moves the circle northwards, even though the circle and the square are far apart.
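The two relational settings differ only in which other objects inherit the actor's move. A small sketch of that selection rule, reusing the scene dictionaries from the sketch above; the mode names are hypothetical:

```python
def moved_objects(scene, actor_idx, mode):
    """Return indices of objects that move together with the actor.

    mode == "adjacent": same color and next to the actor (first relational setting).
    mode == "global":   same color anywhere in the scene (second relational setting).
    """
    actor = scene[actor_idx]
    moved = {actor_idx}
    for i, obj in enumerate(scene):
        if i == actor_idx or obj["color"] != actor["color"]:
            continue
        dist = abs(obj["pos"][0] - actor["pos"][0]) + abs(obj["pos"][1] - actor["pos"][1])
        if mode == "global" or (mode == "adjacent" and dist == 1):
            moved.add(i)
    return moved

scene = [{"shape": "square", "color": "red", "pos": (1, 1)},
         {"shape": "circle", "color": "red", "pos": (1, 2)},
         {"shape": "triangle", "color": "green", "pos": (4, 4)}]
print(moved_objects(scene, actor_idx=0, mode="adjacent"))   # {0, 1}: the adjacent red circle moves too
```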
The input to the model is an image of the current state, and the object-factorized action on which to condition the next state.
We process the image into a set of entities: a pretrained SAM produces a segmentation mask for each entity, and a finetuned ResNet produces a latent vector for each entity.
Each vector is decoded to an image using a spatial decoder. The spatial decoder is trained in conjunction with the entity encoder. In practice, we warm-start the encoder and decoder to ensure good auto-encoder reconstructions.
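Schematically, the entity-encoding stage could look like the sketch below. The SAM step is abstracted away as pre-computed masked crops (one per entity), and the ResNet encoder and spatial-broadcast-style decoder are simplified stand-ins for the modules described in the paper.

```python
import torch
import torch.nn as nn
from torchvision.models import resnet18

class EntityEncoder(nn.Module):
    """Maps each masked entity crop (from SAM) to a latent vector; stand-in for the finetuned ResNet."""
    def __init__(self, dim=128):
        super().__init__()
        backbone = resnet18(weights=None)
        backbone.fc = nn.Linear(backbone.fc.in_features, dim)
        self.backbone = backbone

    def forward(self, entity_crops):                   # (K, 3, H, W): one masked crop per entity
        return self.backbone(entity_crops)             # (K, dim)

class SpatialDecoder(nn.Module):
    """Spatial-broadcast-style decoder: tile each latent over a grid and decode it back to pixels."""
    def __init__(self, dim=128, size=64):
        super().__init__()
        self.size = size
        self.net = nn.Sequential(
            nn.Conv2d(dim + 2, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 4, 3, padding=1),            # RGB + per-entity alpha mask
        )

    def forward(self, z):                              # (K, dim)
        K, dim = z.shape
        coords = torch.stack(torch.meshgrid(
            torch.linspace(-1, 1, self.size),
            torch.linspace(-1, 1, self.size), indexing="ij"))   # (2, H, W) coordinate grid
        tiled = z[:, :, None, None].expand(K, dim, self.size, self.size)
        return self.net(torch.cat([tiled, coords.expand(K, -1, -1, -1)], dim=1))
```

In a scheme like this, the per-entity decodings (RGB plus an alpha mask) would then be composited into the full image reconstruction.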
We use a pretrained CLIP model to predict the symbolic attribute that most likely describes the entity. Notice that the attribute labelled here (C_shape = ⚫️) ignores all other attributes of the entity.
The process is repeated for other attributes as well. Notice that each attribute is labelled independently of the others, allowing the model to trivially generalize to different compositions of attributes.
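A sketch of this zero-shot labelling step using the OpenAI clip package; the attribute vocabularies and prompt templates below are illustrative assumptions, not necessarily the ones used in the paper:

```python
import torch
import clip  # OpenAI CLIP; any CLIP-style model works here

# Hypothetical attribute vocabularies; the real label sets come from the environment spec.
ATTRIBUTES = {
    "shape": ["square", "circle", "triangle"],
    "color": ["red", "green", "blue", "purple"],
}

model, preprocess = clip.load("ViT-B/32")

@torch.no_grad()
def label_entity(entity_image):
    """Label each attribute of one segmented entity independently via zero-shot CLIP."""
    image = preprocess(entity_image).unsqueeze(0)
    image_feat = model.encode_image(image)
    labels = {}
    for attr, values in ATTRIBUTES.items():
        text_feat = model.encode_text(clip.tokenize([f"a {v} object" for v in values]))
        sims = torch.cosine_similarity(image_feat, text_feat)
        labels[attr] = values[sims.argmax().item()]    # independent argmax per attribute
    return labels
```

Because each attribute gets its own independent argmax, labelling a red square never requires having seen "red" and "square" together.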
Each label is represented as a one-hot vector. In practice, this discrete representation does not align well with downstream attention-based modules. Hence, the one-hot vector is used to select a learnable vector from a set of learnable vectors.
Thus, the resultant symbol vector is a composition of learnable latent vectors, one per attribute value. Furthermore, we can ensure a canonical ordering of the symbols, making downstream attention-based computations invariant to permutations of attributes.
The symbolic labelling process is repeated for each entity in the scene. The resultant symbolic vectors are stacked to form a symbolic encoding of the scene.
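A minimal sketch of this construction: one learnable codebook (an nn.Embedding) per attribute, queried with the predicted label ids and stacked in a fixed, canonical attribute order. Attribute names and dimensions are hypothetical.

```python
import torch
import torch.nn as nn

class SymbolicEncoder(nn.Module):
    """Turns per-attribute labels into learnable symbol vectors in a canonical attribute order."""
    def __init__(self, attr_sizes, dim=32):
        super().__init__()
        self.attr_names = sorted(attr_sizes)                              # canonical ordering
        self.codebooks = nn.ModuleDict(
            {name: nn.Embedding(attr_sizes[name], dim) for name in self.attr_names})

    def forward(self, labels):
        """labels: dict mapping attribute name -> LongTensor of label ids, shape (K,) for K entities."""
        per_attr = [self.codebooks[name](labels[name]) for name in self.attr_names]
        return torch.stack(per_attr, dim=1)                               # (K, n_attrs, dim)

encoder = SymbolicEncoder({"shape": 3, "color": 4})
labels = {"shape": torch.tensor([0, 1, 2]), "color": torch.tensor([0, 0, 1])}
symbols = encoder(labels)   # symbolic encoding of a 3-entity scene, shape (3, 2, 32)
```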
Following NPS, we break the transition function into two parts: learning to select a module, and learning to specialize the module to a task. However, how should we employ the neural and symbolic encodings?
The symbolic encoding will help the selection module be robust to attribute compositions. However, if we just use the symbolic encoding, we will risk bottlenecking the model's ability to learn fine-grained dynamics-relevant attributes that may not be known ahead of time.
The neural encoding, on the other hand, captures rich dynamics-relevant attributes which will enable good reconstruction. However, we will risk overfitting to attribute compositions seen during training.
We solve this problem by employing a hybrid approach. We'll use the symbolic encoding to select a module and the neural encoding to predict the next state.
We'll now describe the rest of the architecture.
The symbolic encoding is concatenated with the action vector after the encoding and action are reordered to match the canonical ordering of the symbols. The concatenated vector is used to select a learnable module, which is used to predict the next state.
Note that the symbolic encoding is only selecting the module, which leaves room for the neural encoding to learn fine-grained dynamics-relevant attributes that may not be known ahead of time.
The selected module is applied to the neural encoding to predict the next state.
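Putting the two roles together, a sketch of the select-then-apply transition could look as follows: the (flattened) symbolic vector plus the action scores a bank of modules, a near-one-hot Gumbel-softmax picks one, and the selected module transforms the neural encoding. The straight-through Gumbel-softmax and the residual update are illustrative choices, not necessarily those made in the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectThenApply(nn.Module):
    """Symbolic encoding (+ action) selects a module; the module is applied to the neural encoding."""
    def __init__(self, sym_dim, neural_dim, action_dim, n_modules=4):
        super().__init__()
        self.selector = nn.Linear(sym_dim + action_dim, n_modules)
        self.module_bank = nn.ModuleList(
            [nn.Sequential(nn.Linear(neural_dim + action_dim, 128), nn.ReLU(),
                           nn.Linear(128, neural_dim))
             for _ in range(n_modules)])

    def forward(self, sym, z, action):
        """sym: (K, sym_dim) flattened symbols, z: (K, neural_dim) neural vectors, action: (K, action_dim)."""
        logits = self.selector(torch.cat([sym, action], dim=-1))          # symbols pick the rule
        weights = F.gumbel_softmax(logits, hard=True, dim=-1)             # (K, n_modules), near one-hot
        outputs = torch.stack([m(torch.cat([z, action], dim=-1))
                               for m in self.module_bank], dim=1)         # (K, n_modules, neural_dim)
        delta = (weights.unsqueeze(-1) * outputs).sum(dim=1)
        return z + delta                                                  # predicted next neural encoding
```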
We re-use the spatial decoder to decode the predicted next state into an image.
The model is trained end-to-end using a mixture of the next-state reconstruction error (MSE) and the auto-encoder reconstruction error (AE-MSE; not shown).
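In code, the objective is simply a weighted sum of two MSE terms; the function name and the mixing coefficient below are hypothetical.

```python
import torch.nn.functional as F

def cosmos_loss(pred_next_image, next_image, recon_image, image, ae_weight=1.0):
    """Mixture of next-state prediction error (MSE) and auto-encoder reconstruction error (AE-MSE)."""
    next_state_mse = F.mse_loss(pred_next_image, next_image)   # error on the predicted next frame
    ae_mse = F.mse_loss(recon_image, image)                     # error on reconstructing the current frame
    return next_state_mse + ae_weight * ae_mse
```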
We first look at qualitative results in Entity Composition. A comprehensive analysis is available in the paper.
Each row represents a random sample of objects and a randomly sampled action. The model has never seen this composition of objects before.
We expect strong performance from baselines as the dynamics are invariant to the sampled composition.
COSMOS is able to predict the next state with high fidelity.
Our first baseline is a modified version of NPS with a slot-action alignment attention mechanism. This is equivalent to an ablation of COSMOS without the symbolic representation. NPS is also able to predict the next state with good fidelity. Slight deviations are emphasized with a red dotted box.
Our second baseline uses a GNN to model the interactions between objects. This is related to HOWM and G-SWM. It also serves as an ablation of COSMOS without the symbolic representation and module selection mechanism. The GNN achieves strong performance. However, some entities are reconstructed with wrong attributes, an indication of attribute overfitting.
In this environment, objects related by position (adjacency) and color will move together. As before, the sampled objects have never been seen by the model.
As the dynamics are no longer invariant to compositions, we expect strong performance from COSMOS and overfitting from baselines.
COSMOS is able to predict the next state with high fidelity.
Our first baseline mispredicts the next state. Notice how, in both samples, the model moves the actor and the triangle, even though the triangle doesn't share any attributes with the actor.
The GNN also mispredicts the next state. For more details and results, check out our paper!
This project would not be possible without the excellent work of the community. These are some relevant papers to better understand the premise of this work.
@misc{sehgal2023neurosymbolic,
title={Neurosymbolic Grounding for Compositional World Models},
author={Atharva Sehgal and Arya Grayeli and Jennifer J. Sun and Swarat Chaudhuri},
year={2023},
eprint={2310.12690},
archivePrefix={arXiv},
primaryClass={cs.LG}
}