import numpy as np
# Demonstrates 12 common np.einsum patterns and checks each against the
# equivalent "plain" NumPy expression. The RNG draw order below is fixed, so
# every sampled array is reproducible for seed 42.
rng = np.random.default_rng(42)
m, n, k, b = 4, 5, 3, 8  # b is the batch size (unrelated to the einsum label "b")
A = rng.standard_normal((m, k))  # shape (m, k)
B = rng.standard_normal((k, n))  # shape (k, n)
x = rng.standard_normal((k,))  # shape (k,)
y = rng.standard_normal((m,))  # shape (m,)
u = rng.standard_normal((n,))  # shape (n,)
S = rng.standard_normal((m, m))  # shape (m, m) -- square
BA = rng.standard_normal((b, m, k))  # shape (b, m, k) -- batch of matrices
BB = rng.standard_normal((b, k, n))  # shape (b, k, n)


def _verify(label, got, expected):
    """Raise AssertionError if *got* differs from *expected* beyond tolerance.

    Uses ``numpy.testing`` instead of a bare ``assert`` so the checks still run
    under ``python -O``. The tolerances equal the ``np.allclose``/``np.isclose``
    defaults, so the pass/fail criterion matches the original checks. *label*
    identifies the failing pattern in the error message.
    """
    np.testing.assert_allclose(got, expected, rtol=1e-5, atol=1e-8, err_msg=label)


# --- 1. Dot product: s = sum_i x[i] * x[i] ---
dot_es = np.einsum("i,i->", x, x)
dot_np = np.dot(x, x)
_verify("dot product", dot_es, dot_np)

# --- 2. Outer product: M[i,j] = y[i] * u[j] ---
outer_es = np.einsum("i,j->ij", y, u)  # shape (m, n)
outer_np = np.outer(y, u)
_verify("outer product", outer_es, outer_np)

# --- 3. Matrix-vector: v[i] = sum_k A[i,k] * x[k] ---
mv_es = np.einsum("ik,k->i", A, x)  # shape (m,)
mv_np = A @ x
_verify("matrix-vector", mv_es, mv_np)

# --- 4. Matrix-matrix: C[i,j] = sum_k A[i,k] * B[k,j] ---
mm_es = np.einsum("ik,kj->ij", A, B)  # shape (m, n)
mm_np = A @ B
_verify("matrix-matrix", mm_es, mm_np)

# --- 5. Transposed multiply: C[i,j] = sum_r A[r,i] * A[r,j]  (A^T @ A) ---
# A is (m, k), so A^T is (k, m) and A^T @ A is (k, k). The shared leading
# label "r" runs over A's m rows and is summed away because it is absent
# from the output.
AtA_es = np.einsum("ri,rj->ij", A, A)  # shape (k, k)
AtA_np = A.T @ A
_verify("transposed multiply", AtA_es, AtA_np)

# --- 6. Trace: s = sum_i S[i,i] (repeated label on one operand = diagonal) ---
trace_es = np.einsum("ii->", S)
trace_np = np.trace(S)
_verify("trace", trace_es, trace_np)

# --- 7. Element-wise (Hadamard): C[i,j] = A[i,j] * S2[i,j] ---
S2 = rng.standard_normal((m, k))  # same shape as A (not square, despite the name)
had_es = np.einsum("ij,ij->ij", A, S2)
had_np = A * S2
_verify("hadamard", had_es, had_np)

# --- 8. Row sum: v[i] = sum_j A[i,j] ---
rowsum_es = np.einsum("ij->i", A)  # shape (m,)
rowsum_np = A.sum(axis=1)
_verify("row sum", rowsum_es, rowsum_np)

# --- 9. Column sum: v[j] = sum_i A[i,j] ---
colsum_es = np.einsum("ij->j", A)  # shape (k,)
colsum_np = A.sum(axis=0)
_verify("column sum", colsum_es, colsum_np)

# --- 10. Transpose: T[j,i] = A[i,j] ---
T_es = np.einsum("ij->ji", A)  # shape (k, m)
T_np = A.T
_verify("transpose", T_es, T_np)

# --- 11. Batched matmul: C[b,i,j] = sum_k BA[b,i,k] * BB[b,k,j] ---
bmm_es = np.einsum("bik,bkj->bij", BA, BB)  # shape (b, m, n)
bmm_np = BA @ BB
_verify("batched matmul", bmm_es, bmm_np)

# --- 12. Quadratic form: s = xS^T S_sq xS = sum_ij xS[i] * S_sq[i,j] * xS[j] ---
S_sq = S @ S.T  # S S^T is positive semi-definite, shape (m, m)
xS = rng.standard_normal((m,))
qf_es = np.einsum("i,ij,j->", xS, S_sq, xS)
qf_np = xS @ S_sq @ xS
_verify("quadratic form", qf_es, qf_np)

print("All 12 einsum patterns verified.")