ggpx_statistics/similarity.py

161 lines
5.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from io import BytesIO
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams
# 全局设置中文字体
rcParams['font.sans-serif'] = ['SimHei'] # 设置中文字体为黑体
rcParams['axes.unicode_minus'] = False # 解决坐标轴负号显示问题
# 计算相关系数矩阵
def row_correlations(df):
num_rows = df.shape[0]
corr_matrix = np.zeros((num_rows, num_rows))
# 标准化数据,并处理标准差为零的情况
def standardize_column(x):
std = x.std()
return (x - x.mean()) / std if std != 0 else np.zeros_like(x)
standardized_df = df.apply(standardize_column, axis=0)
for i in range(num_rows):
for j in range(i, num_rows):
r = np.corrcoef(standardized_df.iloc[i], standardized_df.iloc[j])[0, 1]
corr_matrix[i, j] = r
corr_matrix[j, i] = r
corr_df = pd.DataFrame(corr_matrix, index=df.index, columns=df.index)
return corr_df
def row_euclidean_distances(df):
# 获取DataFrame的行数
num_rows = df.shape[0]
# 初始化一个空的对称方阵用于存储欧氏距离
distance_matrix = np.zeros((num_rows, num_rows))
# 计算每两行之间的欧氏距离
for i in range(num_rows):
for j in range(i, num_rows):
# 计算第i行和第j行的欧氏距离
distance = np.linalg.norm(df.iloc[i] - df.iloc[j])
# 存储在对称方阵中
distance_matrix[i, j] = distance
distance_matrix[j, i] = distance
# 将欧氏距离矩阵转换为DataFrame
distance_df = pd.DataFrame(distance_matrix, index=df.index, columns=df.index)
return distance_df
def calculate_similarity(correlation_df, distance_df):
# 初始化一个空的对称方阵用于存储相似度
num_rows = correlation_df.shape[0]
similarity_matrix = np.zeros((num_rows, num_rows))
# 计算相似度S
for i in range(num_rows):
for j in range(i, num_rows):
r = correlation_df.iloc[i, j]
d = distance_df.iloc[i, j]
R = (1 + r) / 2
S = 100 * R / d if d != 0 else np.inf # 如果距离为0相似度设为无穷大
# 存储在对称方阵中
similarity_matrix[i, j] = S
similarity_matrix[j, i] = S
# 将相似度矩阵转换为DataFrame
similarity_df = pd.DataFrame(similarity_matrix, index=correlation_df.index, columns=correlation_df.index)
return similarity_df
# 绘制相似度网络图
def plot_similarity_network(similarity_df, threshold=0.5):
G = nx.Graph()
for i in range(len(similarity_df)):
for j in range(i + 1, len(similarity_df)):
if similarity_df.iloc[i, j] > threshold:
G.add_edge(similarity_df.index[i], similarity_df.columns[j], weight=similarity_df.iloc[i, j])
pos = nx.spring_layout(G)
plt.figure(figsize=(10, 8))
edges = G.edges(data=True)
# 根据相似度调整边的宽度
weights = [e['weight'] for u, v, e in edges]
# 设置颜色为浅色调
nx.draw(G, pos, with_labels=True, node_size=700, node_color='lightblue', font_size=10,
width=2, edge_color=weights, edge_cmap=plt.cm.Blues, edge_vmin=min(weights),
edge_vmax=max(weights)*0.5)
# 显示边的相似度值
edge_labels = {(u, v): f'{e["weight"]:.2f}' for u, v, e in edges}
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
plt.suptitle("相似度网络图", fontsize=16)
# plt.title('Similarity Network Graph')
# 将图像保存到字节流中
img_bytes = BytesIO()
plt.savefig(img_bytes, format='png')
# 设置字节流的位置到开始
img_bytes.seek(0)
# 输出图像
plt.close()
return img_bytes
# 绘制热力图
def plot_similarity_heatmap(similarity_df):
plt.figure(figsize=(10, 8))
sns.heatmap(similarity_df, annot=True, fmt=".2f", cmap="coolwarm", linewidths=.5)
plt.title("相关系数热力图", fontsize=16)
plt.xlabel('样品编号')
plt.ylabel('样品编号')
# 将图像保存到字节流中
img_bytes = BytesIO()
plt.savefig(img_bytes, format='png')
# 设置字节流的位置到开始
img_bytes.seek(0)
# 输出图像
plt.close()
return img_bytes
# 示例用法
if __name__ == '__main__':
# 示例矩阵,和之前的扇形图要求一样
matrix = pd.read_csv('./radartest.csv', index_col=0)
# 调用函数
correlation_result = row_correlations(matrix)
distance_result = row_euclidean_distances(matrix)
similarity_result = calculate_similarity(correlation_result, distance_result)
# 输出相关系数矩阵
print("每两行之间的相关系数矩阵:")
print(correlation_result)
# 输出欧氏距离矩阵
print("每两行之间的欧氏距离矩阵:")
print(distance_result)
# 输出相似度矩阵
print("每两行之间的相似度矩阵:")
print(similarity_result)
# 绘制相关系数热力图
plot_similarity_heatmap(correlation_result)
# 绘制相似度网络图
plot_similarity_network(similarity_result, threshold=0.5)