sklearn 笔记:neighbors.NearestNeighbors 自定义metric

news/2024/7/9 8:34:28 标签: sklearn, 笔记, 人工智能

1 数据

假设我们有这样的一个数据tst_lst,表示的是5条轨迹的墨卡托坐标,我们希望算出逐点的曼哈顿距离之和,作为两条轨迹的距离

[array([[11549759.51313693,   148744.89246911],
        [11549751.49813359,   148732.97804463],
        [11549757.62070558,   148738.21148336],
        [11549877.73443613,   148886.64075531],
        [11549855.1365795 ,   148900.67083319]]),
 array([[11556428.51911408,   145454.58226351],
        [11557035.91165162,   145493.83259114],
        [11557310.50343952,   145408.66217089],
        [11557748.16714946,   145339.9824732 ],
        [11558124.96136184,   145498.27539452]]),
 array([[11560299.60987809,   143642.48133694],
        [11560236.88134503,   143437.08940241],
        [11560254.26944949,   143331.75455279],
        [11560222.79942945,   143349.26953089],
        [11560224.0350758 ,   143354.70329418]]),
 array([[11559757.30584681,   143885.2194761 ],
        [11560304.02926187,   143639.87580025],
        [11560743.21804884,   143750.12120076],
        [11560626.52182665,   144103.28312704],
        [11560722.44583186,   144272.53199179]]),
 array([[11569978.06036478,   151723.38135785],
        [11569938.73118869,   151248.5811628 ],
        [11569616.11617246,   150791.67584703],
        [11569571.34347327,   150687.55191842],
        [11569688.57402901,   150674.10077112]])]

2 处理原始数据

2.1 直接喂入的问题

如果直接将上面的数据fit入NearestNeighbors,是会报错的:

from sklearn.neighbors import NearestNeighbors

cellKDtree=NearestNeighbors().fit(tst_lst)
cellKDtree
'''
ValueError: Found array with dim 3. NearestNeighbors expected <= 2.
'''

ValueError 是由于尝试在 NearestNeighbors 对象上使用三维数组导致的。NearestNeighbors 期望的输入是一个二维数组,其中每行代表一个数据点,每列代表一个特征

2.2 修改数据形状

每一个轨迹二维矩阵转化成一个一维向量

tst_lst=np.array(tst_lst)
tst_lst_new=[]

for i in range(len(tst_lst)):
    tst_lst_new.append(np.hstack(tst_lst[i]).tolist())
tst_lst_new

'''
[[11549759.513136925,
  148744.89246911363,
  11549751.49813359,
  148732.97804463338,
  11549757.620705582,
  148738.2114833576,
  11549877.734436132,
  148886.6407553058,
  11549855.136579504,
  148900.67083319122],
 [11556428.519114085,
  145454.58226351053,
  11557035.911651615,
  145493.83259113596,
  11557310.503439516,
  145408.66217089174,
  11557748.167149458,
  145339.9824731981,
  11558124.961361844,
  145498.2753945235],
 [11560299.609878086,
  143642.48133694328,
  11560236.881345032,
  143437.0894024146,
  11560254.269449493,
  143331.75455278732,
  11560222.79942945,
  143349.26953088713,
  11560224.035075797,
  143354.7032941798],
 [11559757.305846812,
  143885.21947610297,
  11560304.02926187,
  143639.8758002481,
  11560743.218048835,
  143750.12120075937,
  11560626.521826653,
  144103.28312704086,
  11560722.445831856,
  144272.53199179273],
 [11569978.060364777,
  151723.38135785353,
  11569938.731188687,
  151248.58116280191,
  11569616.116172463,
  150791.67584703089,
  11569571.343473272,
  150687.55191841844,
  11569688.57402901,
  150674.1007711226]]
'''

此时送入NearestNeighbor已经可以了

from sklearn.neighbors import NearestNeighbors

cellKDtree=NearestNeighbors().fit(tst_lst_new)
cellKDtree

3 自定义函数

from scipy.spatial.distance import *
import numpy as np
def disfunc(x,y):
    #每次比较fit入Nearest Neighbor 的矩阵的两行

    x_points=np.array([(x[i],x[i+1]) for i in range(0,len(x),2)])
    y_points=np.array([(y[i],y[i+1]) for i in range(0,len(y),2)])
    #提取经纬度,将每一行一维向量改成二维矩阵

    return float(np.sum(np.diag(cdist(x_points,y_points,metric='cityblock'))))
    '''
    cdist(x_points,y_points,metric='cityblock') 将得到一个二维矩阵,表示x每一个元素和y每一个元素的曼哈顿距离
    np.diag是取二维矩阵的对角元素,表示x和y对应位置元素的距离
    求和就是两条轨迹的距离
    '''

4 使用NearestNeighbor

注:似乎algorithm只能选择默认的brute,KD_tree和ball_tree都不行

from sklearn.neighbors import *

cellKDtree=NearestNeighbors(metric=disfunc).fit(tst_lst_new)
cellKDtree


http://www.niftyadmin.cn/n/5239087.html

相关文章

头歌—密码学基础

第1关&#xff1a;哈希函数 题目 任务描述 本关任务&#xff1a;利用哈希算法统计每个字符串出现的个数。 相关知识 为了完成本关任务&#xff0c;你需要掌握&#xff1a;1.密码学哈希函数的概念及特性&#xff0c;2.安全哈希算法。 密码学哈希函数的概念及特性 我们需要…

ElasticSearch之Force merge API

使用本方法&#xff0c;可以触发强制合并操作。 默认情况下&#xff0c;ElasticSearch会在后台周期性触发合并操作&#xff0c;因此不需要用户刻意使用本方法。 使用强制合并的弊端&#xff1a; 可能会产生大于5G的segment对象&#xff0c;而ElasticSearch后台自动触发的合并…

操作系统·设备管理

I/O系统是计算机系统的重要组成部分&#xff0c;是OS中最复杂且与硬件密切相关的部分 I/O系统的基本任务是完成用户提出的I/O请求&#xff0c;提高I/O速率以及改善I/O设备的利用率&#xff0c;方便高层进程对IO设备的使用 I/O系统包括用于实现信息输入、输出和存储功能的设备和…

leetcode做题笔记1038. 从二叉搜索树到更大和树

给定一个二叉搜索树 root (BST)&#xff0c;请将它的每个节点的值替换成树中大于或者等于该节点值的所有节点值之和。 提醒一下&#xff0c; 二叉搜索树 满足下列约束条件&#xff1a; 节点的左子树仅包含键 小于 节点键的节点。节点的右子树仅包含键 大于 节点键的节点。左右…

配置BFD多跳检测示例

BFD简介 定义 双向转发检测BFD&#xff08;Bidirectional Forwarding Detection&#xff09;是一种全网统一的检测机制&#xff0c;用于快速检测、监控网络中链路或者IP路由的转发连通状况。 目的 为了减小设备故障对业务的影响&#xff0c;提高网络的可靠性&#xff0c;网…

云原生周刊:K8s 的 YAML 技巧 | 2023.12.4

开源项目推荐 Helmfile Helmfile 是用于部署 Helm Chart 的声明性规范。其功能有&#xff1a; 保留图表值文件的目录并维护版本控制中的更改。将 CI/CD 应用于配置更改。定期同步以避免环境偏差。 Docketeer 一款 Docker 和 Kubernetes 开发人员工具&#xff0c;用于管理容…

LeetCode103. Binary Tree Zigzag Level Order Traversal

文章目录 一、题目二、题解 一、题目 Given the root of a binary tree, return the zigzag level order traversal of its nodes’ values. (i.e., from left to right, then right to left for the next level and alternate between). Example 1: Input: root [3,9,20,n…

「Verilog学习笔记」时钟分频(偶数)

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点&#xff0c;刷题网站用的是牛客网 timescale 1ns/1nsmodule even_div(input wire rst ,input wire clk_in,output wire clk_out2,output wire clk_out4,output wire clk_out8); //********…