Machine Learning: Science and Technology (Jan 2024)
Concept graph embedding models for enhanced accuracy and interpretability
Abstract
In fields requiring high accountability, it is necessary to understand how deep-learning models reach their decisions when analyzing the causes of an image classification. Concept-based interpretation methods have recently been introduced to reveal the internal mechanisms of deep learning models using high-level concepts. However, such methods are constrained by a trade-off between accuracy and interpretability. For instance, in real-world environments, unlike in well-curated training data, accurately predicting the expected concepts becomes a challenge owing to the various distortions and complexities introduced by different objects. To overcome this trade-off, we propose concept graph embedding models (CGEM), which reflect the complex dependencies and structures among concepts by learning their mutual directionalities. The concept graph convolutional neural network (Concept GCN), a downstream task of CGEM, differs from previous methods, which solely determine the presence of concepts, in that it performs the final classification based on the relationships between concepts learned through graph embedding. This process endows the model with high resilience even in the presence of incorrect concepts. In addition, we use a deformable bipartite GCN for object-centric concept encoding in the earlier stages, which enhances the homogeneity of the concepts. The experimental results show that, based on deformable concept encoding, CGEM mitigates the trade-off between task accuracy and interpretability. Moreover, we confirm that this approach increases resilience and interpretability while maintaining robustness against various real-world concept distortions and incorrect concept interventions. Our code is available at https://github.com/jumpsnack/cgem.
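The central idea summarized above, classifying an image from the relationships among detected concepts rather than from their individual presence, can be illustrated with a minimal sketch. The snippet below is not the authors' implementation (see the repository for that); it is a hypothetical NumPy example in which per-image concept activations are scaled into concept embeddings, propagated over an assumed concept adjacency matrix with a single GCN step, and then read out by a linear classifier. All sizes, the adjacency matrix, and the weights are illustrative assumptions.

```python
# Hypothetical sketch (not the authors' code): classify an image from concept
# activations by propagating them over a concept graph with one GCN layer.
import numpy as np

rng = np.random.default_rng(0)

NUM_CONCEPTS = 8      # number of high-level concepts (assumed)
EMBED_DIM = 16        # concept embedding size (assumed)
NUM_CLASSES = 3       # downstream task classes (assumed)

# Directed adjacency among concepts; in CGEM these dependencies are learned,
# here we use a fixed random graph purely for illustration.
A = (rng.random((NUM_CONCEPTS, NUM_CONCEPTS)) > 0.7).astype(float)
A_hat = A + np.eye(NUM_CONCEPTS)                 # add self-loops
D_inv = np.diag(1.0 / A_hat.sum(axis=1))         # row normalization
A_norm = D_inv @ A_hat

# Per-image concept activations (e.g. from an object-centric concept encoder).
concept_scores = rng.random(NUM_CONCEPTS)        # shape: (NUM_CONCEPTS,)

# Parameters (random here; learned in practice): concept embeddings,
# GCN weight, and the readout classifier.
concept_emb = rng.normal(size=(NUM_CONCEPTS, EMBED_DIM))
W_gcn = rng.normal(size=(EMBED_DIM, EMBED_DIM))
W_cls = rng.normal(size=(NUM_CONCEPTS * EMBED_DIM, NUM_CLASSES))

def relu(x):
    return np.maximum(x, 0.0)

# Scale each concept embedding by its activation, then propagate over the graph,
# so the prediction depends on relations between concepts rather than on each
# concept's presence alone.
X = concept_scores[:, None] * concept_emb        # (NUM_CONCEPTS, EMBED_DIM)
H = relu(A_norm @ X @ W_gcn)                     # one message-passing step
logits = H.reshape(-1) @ W_cls                   # flatten + linear classifier
print("predicted class:", int(np.argmax(logits)))
```

Because the classifier sees propagated concept features rather than raw presence scores, a single corrupted activation can be partially compensated by its neighbors in the graph, which is the intuition behind the resilience to incorrect concepts claimed above.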
Keywords