The AiEdge Newsletter

Deep Dive: How I taught ChatGPT to Draw Diagrams with LangChain

Building Machine Learning Solutions

Damien Benveniste
Jun 15, 2023

I have been working my way through generating automated Machine Learning content using ChatGPT and LangChain. I still have a long way to go, but let me share my experience attempting to generate explanatory diagrams with those tools. We cover:

  • Drawing diagrams with Mermaid

  • Teaching ChatGPT to draw diagrams

    • Scraping the Mermaid documentation

    • Loading the data into a vector database

    • Generating diagrams with ChatGPT

  • Explaining an article with diagrams

    • Getting the article

    • Extracting the concepts to explain

    • Explaining concepts

    • Describing diagrams

    • Translating into Mermaid code

  • Example of an article explained with ChatGPT

    • Transformer

    • Self-attention

    • Attention mechanisms

    • Sequence transduction models

    • Parallelizable models


Drawing diagrams with Mermaid

I wanted to find an easy charting tool that ChatGPT could use to visually explain complex concepts. Mermaid is a simple JavaScript library that generates diagrams from a text-based syntax. For example, the following code:

graph LR;
    A--> B & C & D;
    B--> A & E;
    C--> A & E;
    D--> A & E;
    E--> B & C & D; 

generates this diagram:

To render those diagrams, you can use these online editors:

  • Mermaid editor 1

  • Mermaid editor 2

Additionally, you can render those diagrams in a Jupyter notebook using the following Python script:

import base64
from IPython.display import Image, display

def visualize(graph):
    # Encode the diagram source as URL-safe base64 (UTF-8 so non-ASCII labels work)
    graphbytes = graph.encode('utf8')
    base64_bytes = base64.urlsafe_b64encode(graphbytes)
    base64_string = base64_bytes.decode('ascii')
    # mermaid.ink renders the encoded diagram and serves it as an image
    display(Image(url='https://mermaid.ink/img/' + base64_string))
    
visualize("""
graph LR;
    A--> B & C & D;
    B--> A & E;
    C--> A & E;
    D--> A & E;
    E--> B & C & D;
""")

Teaching ChatGPT to draw diagrams

Scraping the Mermaid documentation

To be honest, ChatGPT was already trained on a previous version of the Mermaid documentation, so it could already generate Mermaid code, but I wanted to showcase how I would do it if it hadn't been. Here, I will scrape the Mermaid documentation and make it available to ChatGPT to help it generate diagrams. The main goal is to demonstrate LangChain’s scraping capability, but I found that it also helped the LLM make fewer mistakes when generating the code, and it could help in case the syntax has evolved since the LLM was trained.

To scrape the Mermaid website, we use Apify. After signing up, you can get your API key by clicking the “API” button in the settings section.

Let’s make sure to capture the Apify API key as an environment variable:

import os
os.environ['APIFY_API_TOKEN'] = ...

Scraping the whole website with the Apify wrapper in LangChain is quite easy. We just need to run the following code:

from langchain.document_loaders.base import Document
from langchain.utilities import ApifyWrapper

# Reads the token from the APIFY_API_TOKEN environment variable
apify = ApifyWrapper()

url = 'https://mermaid.js.org/'

# Run the crawler actor and wrap its output dataset in a document loader
loader = apify.call_actor(
    actor_id='apify/website-content-crawler',
    run_input={'startUrls': [{'url': url}]},
    # Convert each crawled page into a LangChain Document
    dataset_mapping_function=lambda item: Document(
        page_content=item['text'] or '',
        metadata={'source': item['url']}
    ),
)

Here we used the 'apify/website-content-crawler' actor to recursively scrape the whole website, starting from the URL https://mermaid.js.org/. You can find more information about this scraper in the Apify console.

You can find the scraping results in the “runs” section.
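The `dataset_mapping_function` passed to the actor call is the piece that turns each crawled page into a document. Here is a minimal sketch of that mapping as a named function. `Doc` is a stand-in for LangChain's `Document` so the example runs without the library, `to_document` is a hypothetical name, and the sample items are invented. It assumes each item is a dict with `text` and `url` keys, as in the lambda above.

```python
from dataclasses import dataclass, field


@dataclass
class Doc:
    """Stand-in for LangChain's Document (assumption, for illustration only)."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def to_document(item: dict) -> Doc:
    # 'text' can be None or missing for pages with no extractable content,
    # so fall back to an empty string instead of failing.
    return Doc(
        page_content=item.get("text") or "",
        metadata={"source": item.get("url", "")},
    )


# Invented sample items shaped like the crawler output assumed above
pages = [
    {"url": "https://example.com/docs/flowchart", "text": "Flowchart syntax"},
    {"url": "https://example.com/docs/empty", "text": None},
]
docs = [to_document(p) for p in pages]
print([d.page_content for d in docs])
```

Keeping the source URL in the metadata is what later lets a query report which page an answer came from.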

Loading the data into a vector database

We now index that data into a local vector database to make it searchable. We use OpenAI's text embeddings to convert the text into vectors.

Before continuing, make sure to get your OpenAI API key by signing up on the OpenAI platform, and capture it as an environment variable:

os.environ['OPENAI_API_KEY'] = ...

To move the data into a vector database, we simply run:

from langchain.indexes import VectorstoreIndexCreator

index = VectorstoreIndexCreator().from_loaders([loader])

The current default vector store is ChromaDB. We can now query the database:

query = 'What is the syntax for flowcharts?'
result = index.query_with_sources(query)
result

> 'question': 'What is the syntax for flowcharts?'

'answer': ' The syntax for flowcharts includes nodes (geometric shapes) and edges (arrows or lines). Special characters can be escaped using quotes or entity codes, and subgraphs can be defined using the "subgraph" keyword.'

'sources': 'https://mermaid.js.org/syntax/flowchart.html'
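Under the hood, a query like this one is embedded into a vector and compared against the stored document vectors, and the closest few are handed to the LLM. The sketch below illustrates that idea with toy 3-dimensional vectors and cosine similarity. The vectors and names (`cosine`, `top_k`) are invented for illustration, and the actual distance metric and index structure depend on how the vector store is configured.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query_vec, doc_vecs, k):
    """Return the ids of the k documents most similar to query_vec."""
    ranked = sorted(
        doc_vecs,
        key=lambda doc_id: cosine(query_vec, doc_vecs[doc_id]),
        reverse=True,
    )
    return ranked[:k]


# Toy "embeddings": real ones have hundreds or thousands of dimensions
doc_vecs = {
    "flowchart": [0.9, 0.1, 0.0],
    "sequence": [0.1, 0.9, 0.0],
    "gantt": [0.0, 0.2, 0.9],
}
print(top_k([1.0, 0.2, 0.0], doc_vecs, k=2))  # ['flowchart', 'sequence']
```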

Generating diagrams with ChatGPT

Now that we have the data, let's ask ChatGPT to generate some diagrams. We can turn the index into a retriever for the LLM to use:

retriever = index.vectorstore.as_retriever()
# Change the number of documents to return
retriever.search_kwargs['k'] = 10

Here, ‘k’ is the number of documents the retriever returns for each query. We can now create a chain to augment ChatGPT with that database:

from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI()

mermaid_qa = RetrievalQA.from_chain_type(
    llm=llm, 
    retriever=retriever,
)
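`RetrievalQA.from_chain_type` uses the "stuff" chain type by default, which means the retrieved documents are all placed into a single prompt together with the question. The sketch below shows the general idea only. `build_stuffed_prompt` is a hypothetical helper and the template wording is invented; it is not the prompt LangChain actually sends.

```python
def build_stuffed_prompt(question, documents):
    """Concatenate retrieved passages and the question into one prompt."""
    # Separate passages with blank lines so the model can tell them apart
    context = "\n\n".join(documents)
    return (
        "Use the following context to answer the question.\n\n"
        f"{context}\n\n"
        f"Question: {question}\nAnswer:"
    )


prompt = build_stuffed_prompt(
    "What is the syntax for flowcharts?",
    ["Flowcharts are composed of nodes and edges.", "Use 'subgraph' to group nodes."],
)
print(prompt)
```

With `k = 10`, ten passages are stuffed into each prompt, so the value of `k` is a trade-off between giving the model more context and staying within its context window.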

Let’s see if ChatGPT can transform the following text (from Wikipedia) into a diagram:

Machine learning (ML) is a field devoted to understanding and building methods that let machines "learn" – that is, methods that leverage data to improve computer performance on some set of tasks.[1] Machine learning algorithms build a model based on sample data, known as training data, in order to make predictions or decisions without being explicitly programmed to do so.[2] Machine learning algorithms are used in a wide variety of applications, such as in medicine, email filtering, speech recognition, agriculture, and computer vision, where it is difficult or unfeasible to develop conventional algorithms to perform the needed tasks.[3][4] A subset of machine learning is closely related to computational statistics, which focuses on making predictions using computers, but not all machine learning is statistical learning. The study of mathematical optimization delivers methods, theory and application domains to the field of machine learning. Data mining is a related field of study, focusing on exploratory data analysis through unsupervised learning.[6][7] Some implementations of machine learning use data and neural networks in a way that mimics the working of a biological brain.[8][9] In its application across business problems, machine learning is also referred to as predictive analytics.

query = """
Your job is to write the code to generate a colorful mermaid diagram describing the following text. 
Return only the code and make sure it has multiple colors

TEXT: {text}
"""
# `text` holds the Wikipedia paragraph above
result = mermaid_qa.run(query.format(text=text))

Here is the resulting code:

graph LR
A[Machine learning] --> B(Building methods)
B --> C(Leverage data)
C --> D(Improve computer performance)
D --> E[Set of tasks]

A --> F[Training data]
F --> G[Build model]
G --> H[Predictions/Decisions]

A --> I[Wide variety of applications]
I --> J[Medicine]
I --> K[Email filtering]
I --> L[Speech recognition]
I --> M[Agriculture]
I --> N[Computer vision]

A --> O[Subset of machine learning]
O --> P[Computational statistics]
O --> Q[Mathematical optimization]
O --> R[Data mining]
R --> S[Unsupervised learning]

A --> T[Neural networks]
T --> U[Biological brain]

A --> V[Predictive analytics]

style A fill:#F5DEB3,stroke:#A0522D,stroke-width:4px; 
style E fill:#F5DEB3,stroke:#A0522D,stroke-width:4px;
style H fill:#F5DEB3,stroke:#A0522D,stroke-width:4px;
style J fill:#7FFFD4,stroke:#4B0082,stroke-width:4px;
style K fill:#7FFFD4,stroke:#4B0082,stroke-width:4px;
style L fill:#7FFFD4,stroke:#4B0082,stroke-width:4px;
style M fill:#7FFFD4,stroke:#4B0082,stroke-width:4px;
style N fill:#7FFFD4,stroke:#4B0082,stroke-width:4px;
style P fill:#ADD8E6,stroke:#00008B,stroke-width:4px;
style Q fill:#ADD8E6,stroke:#00008B,stroke-width:4px;
style S fill:#ADD8E6,stroke:#00008B,stroke-width:4px;
style U fill:#FF69B4,stroke:#C71585,stroke-width:4px;
style V fill:#F5DEB3,stroke:#A0522D,stroke-width:4px;

and the related diagram
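Even when asked to return only the code, a chat model sometimes wraps its answer in a Markdown code fence or adds a sentence before it, which would break rendering. A small cleanup step before rendering can guard against that. This is a sketch under that assumption; `extract_mermaid` is a hypothetical helper, not part of LangChain or Mermaid.

```python
import re

FENCE = "`" * 3  # three backticks, spelled this way to keep the example readable

# Optional "mermaid" language tag after the opening fence, then the code itself
_FENCED = re.compile(FENCE + r"(?:mermaid)?\s*\n(.*?)" + FENCE, re.DOTALL)


def extract_mermaid(llm_output: str) -> str:
    """Return the Mermaid code from a reply, dropping a Markdown fence if present."""
    match = _FENCED.search(llm_output)
    code = match.group(1) if match else llm_output
    return code.strip()


reply = "Here is the diagram:\n" + FENCE + "mermaid\ngraph LR\nA --> B\n" + FENCE
print(extract_mermaid(reply))
```

The cleaned string can then be rendered in an online editor or in a notebook.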

Explaining an article with diagrams

Getting the article

I wanted to see if I could use LangChain to automate the process of explaining articles. Let’s see if we can explain the famous paper “Attention Is All You Need”.

Let’s use the arXiv Python package (`pip install arxiv`) to download the paper.
