JStarCraft AI
希望路过的同学,顺手给JStarCraft框架点个Star,算是对作者的一种鼓励吧!
JStarCraft AI是一个机器学习的轻量级框架.遵循Apache 2.0协议.
在学术界,绝大多数研究人员使用的编程语言是Python.
在工业界,绝大多数开发人员使用的编程语言是Java.
JStarCraft AI是一个基于Java语言的机器学习工具包,由一系列的数据结构,算法和模型组成.
目标是作为在学术界与工业界从事机器学习研发的相关人员之间的桥梁.普及机器学习在Java领域的应用.
作者 | 洪钊桦 |
---|---|
110399057@qq.com, jstarcraft@gmail.com |
JStarCraft AI架构
JStarCraft AI特性
- 1.数据(data)
- 属性与特征
- 连续
- 离散
- 模块与实例
- 选择,排序与切割
- 属性与特征
- 2.环境(environment)
- 串行计算
- 并行计算
- CPU计算
- GPU计算
- 3.数学(math)
- 算法(algorithm)
- 分解
- 概率
- 相似度
- 损失函数
- 数据结构(structure)
- 标量
- 向量
- 矩阵
- 张量
- 单元
- 表单
- 算法(algorithm)
- 4.调制解调(modem)
- 5.模型(model)
- 线性模型(linear)
- 近邻模型(nearest neighbor)
- 矩阵分解模型(matrix factorization)
- 神经网络模型(neutral network)
- 计算图
- 节点
- 层
- 正向传播与反向传播
- 激活函数
- 梯度更新
- 计算图
- 概率图模型(probabilistic graphical)
- 规则模型(rule)
- 支持向量机模型(support vector machine)
- 树模型(tree)
- 6.优化(optimization)
- 梯度下降法(gradient descent)
- 批量梯度下降(batch gradient descent)
- 随机梯度下降(stochastic gradient descent)
- 牛顿法和拟牛顿法(newton method/quasi newton method)
- 共轭梯度法(conjugate gradient)
- 试探法(heuristic)
- 模拟退火算法(simulate anneal)
- 遗传算法(genetic)
- 蚁群算法(ant colony)
- 粒子群算法(particle swarm)
- 梯度下降法(gradient descent)
- 7.有监督学习(supervised)
- 分类
- 回归
- 8.无监督学习(unsupervised)
- 聚类
- 关联
- 9.丰富的评估指标
JStarCraft AI教程
- 1.设置依赖
- 2.配置环境
- 3.使用数据
Maven依赖
<dependency>
<groupId>com.jstarcraft</groupId>
<artifactId>ai</artifactId>
<version>1.0</version>
</dependency>
Gradle依赖
compile group: 'com.jstarcraft', name: 'ai', version: '1.0'
设置CPU环境
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>1.0.0-beta3</version>
</dependency>
设置GPU环境
- CUDA 9.0
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-9.0-platform</artifactId>
<version>1.0.0-beta3</version>
</dependency>
- CUDA 9.1
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-9.1-platform</artifactId>
<version>1.0.0-beta3</version>
</dependency>
- CUDA 9.2
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-9.2-platform</artifactId>
<version>1.0.0-beta3</version>
</dependency>
- CUDA 10.0
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.0-platform</artifactId>
<version>1.0.0-beta3</version>
</dependency>
- CUDA 10.1
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1-platform</artifactId>
<version>1.0.0-beta3</version>
</dependency>
使用环境上下文
// 获取默认环境上下文
EnvironmentContext context = EnvironmentContext.getContext();
// 在环境上下文中执行任务
Future<?> task = context.doTask(() - > {
int dimension = 10;
MathMatrix leftMatrix = getRandomMatrix(dimension);
MathMatrix rightMatrix = getRandomMatrix(dimension);
MathMatrix dataMatrix = getZeroMatrix(dimension);
dataMatrix.dotProduct(leftMatrix, false, rightMatrix, true, MathCalculator.PARALLEL);
});
数据表示
- 未处理的形式(转换前)
用户(User) | 旧手机类型(Item) | 新手机类型(Item) | 评分(Score) |
---|---|---|---|
Google Fan | Android | Android | 3 |
Google Fan | Android | IOS | 1 |
Google Fan | IOS | Android | 5 |
Apple Fan | IOS | IOS | 3 |
Apple Fan | Android | IOS | 5 |
Apple Fan | IOS | Android | 1 |
- 已处理的形式(转换后)
定性(User) | 定性(Item) | 定性(Item) | 定量(Score) |
---|---|---|---|
0 | 0 | 0 | 3 |
0 | 0 | 1 | 1 |
0 | 1 | 0 | 5 |
1 | 1 | 1 | 3 |
1 | 0 | 1 | 5 |
1 | 1 | 0 | 1 |
数据转换
数据转换器(DataConverter)负责各种各样的格式转换为JStarCraft AI框架能够处理的数据模块(DataModule).
JStarCraft AI框架各个转换器与其它系统之间的关系:
- 定义数据属性
// 定性属性
Map<String, Class<?>> qualityDifinitions = new HashMap<>();
qualityDifinitions.put("user", String.class);
qualityDifinitions.put("item", String.class);
// 定量属性
Map<String, Class<?>> quantityDifinitions = new HashMap<>();
quantityDifinitions.put("score", float.class);
DataSpace space = new DataSpace(qualityDifinitions, quantityDifinitions);
- 定义数据模块
TreeMap<Integer, String> configuration = new TreeMap<>();
configuration.put(1, "user");
configuration.put(3, "item");
configuration.put(4, "score");
DataModule module = space.makeDenseModule("module", configuration, 1000);
JStarCraft AI框架兼容的格式
- ARFF
// ARFF转换器
ArffConverter converter = new ArffConverter(space.getQualityAttributes(), space.getQuantityAttributes());
// 获取流
File file = new File(this.getClass().getResource("module.arff").toURI());
InputStream stream = new FileInputStream(file);
// 转换数据
int count = converter.convert(module, stream, null, null, null);
- CSV
// CSV转换器
CsvConverter converter = new CsvConverter(',', space.getQualityAttributes(), space.getQuantityAttributes());
// 获取流
File file = new File(this.getClass().getResource("module.csv").toURI());
InputStream stream = new FileInputStream(file);
// 转换数据
int count = converter.convert(module, stream, null, null, null);
- JSON
// JSON转换器
JsonConverter converter = new JsonConverter(space.getQualityAttributes(), space.getQuantityAttributes());
// 获取流
File file = new File(this.getClass().getResource("module.json").toURI());
InputStream stream = new FileInputStream(file);
// 转换数据
int count = converter.convert(module, stream, null, null, null);
- HQL
// HQL转换器
QueryConverter converter = new QueryConverter(space.getQualityAttributes(), space.getQuantityAttributes());
// 获取游标
String selectDataHql = "select data.user, data.leftItem, data.rightItem, data.score from MockData data";
Session session = sessionFactory.openSession();
Query query = session.createQuery(selectDataHql);
ScrollableResults iterator = query.scroll();
// 转换数据
int count = converter.convert(module, iterator, null, null, null);
session.close();
- SQL
// SQL转换器
QueryConverter converter = new QueryConverter(space.getQualityAttributes(), space.getQuantityAttributes());
// 获取游标
String selectDataSql = "select user, leftItem, rightItem, score from MockData";
Session session = sessionFactory.openSession();
Query query = session.createQuery(selectDataSql);
ScrollableResults iterator = query.scroll();
// 转换数据
int count = converter.convert(module, iterator, null, null, null);
session.close();
数据处理
- 选择
- 排序
- 切割
评估指标
排序指标
- AUC
- Diversity
- MAP
- MRR
- NDCG
- Novelty
- Precision
- Recall
评分指标
- MAE
- MPE
- MSE/RMSE