import matplotlib.pyplot as plt
from tenpy.models import lattice
# Visualise the named pair families of a periodic ``Ladder`` lattice,
# drawing each family with its own line width.
plt.figure(figsize=(7., 2.))
ax = plt.gca()
lat = lattice.Ladder(4, None, bc='periodic')

# Pair family -> line width; dict order fixes both draw and legend order.
pair_linewidths = {
    'rung_NN': 3.,
    'leg_NN': 2.,
    'diagonal': 1.,
}
for pair_name, width in pair_linewidths.items():
    lat.plot_coupling(
        ax,
        lat.pairs[pair_name],
        linestyle='--',
        linewidth=width,
        label=pair_name,
    )

# Data-less plot call: exists only to add an explanatory legend entry.
ax.plot([], [], ' ', label='nearest_neighbors =\n  rung_NN + leg_NN')

lat.plot_sites(ax)
lat.plot_basis(ax, origin=[-0.5, -0.25], shade=False)

ax.set_aspect('equal')
# Only the left x-limit is pinned; the right limit is left as is.
ax.set_xlim(left=-1.)
ax.set_ylim(bottom=-0.5, top=1.5)

# Anchor the legend's upper-left corner at the axes' upper-right corner,
# i.e. place the legend just outside the plotting area.
ax.legend(loc='upper left', bbox_to_anchor=(1., 1.))
plt.show()