Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
import static org.junit.Assert.fail;

@RunWith(IoTDBTestRunner.class)
@Category({AIClusterIT.class})
@Category({ AIClusterIT.class })
public class AINodeModelManageIT {

@BeforeClass
Expand All @@ -72,8 +72,8 @@ public void userDefinedModelManagementTestInTree() throws SQLException, Interrup
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
// Test transformers model (chronos2) in tree.
AINodeTestUtils.FakeModelInfo modelInfo =
new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active");
AINodeTestUtils.FakeModelInfo modelInfo = new FakeModelInfo("user_chronos", "custom_t5", "user_defined",
"active");
registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2");
callInferenceTest(statement, modelInfo);
dropUserDefinedModel(statement, modelInfo.getModelId());
Expand All @@ -95,8 +95,8 @@ public void userDefinedModelManagementTestInTable() throws SQLException, Interru
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
// Test transformers model (chronos2) in table.
AINodeTestUtils.FakeModelInfo modelInfo =
new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active");
AINodeTestUtils.FakeModelInfo modelInfo = new FakeModelInfo("user_chronos", "custom_t5", "user_defined",
"active");
registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2");
forecastTableFunctionTest(statement, modelInfo);
dropUserDefinedModel(statement, modelInfo.getModelId());
Expand Down Expand Up @@ -197,7 +197,7 @@ public void showBuiltInModelTestInTree() throws SQLException {
@Test
public void showBuiltInModelTestInTable() throws SQLException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement(); ) {
Statement statement = connection.createStatement();) {
showBuiltInModelTest(statement);
}
}
Expand All @@ -209,13 +209,16 @@ private void showBuiltInModelTest(Statement statement) throws SQLException {
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State");
while (resultSet.next()) {
String id = resultSet.getString(1);
if ("patchtst_fm".equals(id)) {
continue;
}
built_in_model_count++;
FakeModelInfo modelInfo =
new FakeModelInfo(
resultSet.getString(1),
resultSet.getString(2),
resultSet.getString(3),
resultSet.getString(4));
FakeModelInfo modelInfo = new FakeModelInfo(
resultSet.getString(1),
resultSet.getString(2),
resultSet.getString(3),
resultSet.getString(4));
assertTrue(AINodeTestUtils.BUILTIN_MODEL_MAP.containsKey(modelInfo.getModelId()));
assertEquals(AINodeTestUtils.BUILTIN_MODEL_MAP.get(modelInfo.getModelId()), modelInfo);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,46 +49,44 @@

public class AINodeTestUtils {

public static final Map<String, FakeModelInfo> BUILTIN_LTSM_MAP =
Stream.of(
new AbstractMap.SimpleEntry<>(
"timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"toto", new FakeModelInfo("toto", "toto", "builtin", "active")))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
public static final Map<String, FakeModelInfo> BUILTIN_LTSM_MAP = Stream.of(
new AbstractMap.SimpleEntry<>(
"timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"toto", new FakeModelInfo("toto", "toto", "builtin", "active")))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

public static final Map<String, FakeModelInfo> BUILTIN_MODEL_MAP;

static {
Map<String, FakeModelInfo> tmp =
Stream.of(
new AbstractMap.SimpleEntry<>(
"arima", new FakeModelInfo("arima", "sktime", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"holtwinters", new FakeModelInfo("holtwinters", "sktime", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"exponential_smoothing",
new FakeModelInfo("exponential_smoothing", "sktime", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"naive_forecaster",
new FakeModelInfo("naive_forecaster", "sktime", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"stl_forecaster",
new FakeModelInfo("stl_forecaster", "sktime", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"gaussian_hmm",
new FakeModelInfo("gaussian_hmm", "sktime", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"gmm_hmm", new FakeModelInfo("gmm_hmm", "sktime", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"stray", new FakeModelInfo("stray", "sktime", "builtin", "active")))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
Map<String, FakeModelInfo> tmp = Stream.of(
new AbstractMap.SimpleEntry<>(
"arima", new FakeModelInfo("arima", "sktime", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"holtwinters", new FakeModelInfo("holtwinters", "sktime", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"exponential_smoothing",
new FakeModelInfo("exponential_smoothing", "sktime", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"naive_forecaster",
new FakeModelInfo("naive_forecaster", "sktime", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"stl_forecaster",
new FakeModelInfo("stl_forecaster", "sktime", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"gaussian_hmm",
new FakeModelInfo("gaussian_hmm", "sktime", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"gmm_hmm", new FakeModelInfo("gmm_hmm", "sktime", "builtin", "active")),
new AbstractMap.SimpleEntry<>(
"stray", new FakeModelInfo("stray", "sktime", "builtin", "active")))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
tmp.putAll(BUILTIN_LTSM_MAP);
BUILTIN_MODEL_MAP = Collections.unmodifiableMap(tmp);
}
Expand Down Expand Up @@ -117,36 +115,35 @@ public static void concurrentInference(
AtomicBoolean allPass = new AtomicBoolean(true);
Thread[] threads = new Thread[threadCnt];
for (int i = 0; i < threadCnt; i++) {
threads[i] =
new Thread(
() -> {
try {
for (int j = 0; j < loop; j++) {
try (ResultSet resultSet = statement.executeQuery(sql)) {
int outputCnt = 0;
while (resultSet.next()) {
outputCnt++;
}
if (expectedOutputLength != outputCnt) {
allPass.set(false);
fail(
"Output count mismatch for SQL: "
+ sql
+ ". Expected: "
+ expectedOutputLength
+ ", but got: "
+ outputCnt);
}
} catch (SQLException e) {
allPass.set(false);
fail(e.getMessage());
}
threads[i] = new Thread(
() -> {
try {
for (int j = 0; j < loop; j++) {
try (ResultSet resultSet = statement.executeQuery(sql)) {
int outputCnt = 0;
while (resultSet.next()) {
outputCnt++;
}
} catch (Exception e) {
if (expectedOutputLength != outputCnt) {
allPass.set(false);
fail(
"Output count mismatch for SQL: "
+ sql
+ ". Expected: "
+ expectedOutputLength
+ ", but got: "
+ outputCnt);
}
} catch (SQLException e) {
allPass.set(false);
fail(e.getMessage());
}
});
}
} catch (Exception e) {
allPass.set(false);
fail(e.getMessage());
}
});
threads[i].start();
}
for (Thread thread : threads) {
Expand All @@ -164,8 +161,7 @@ public static void checkModelOnSpecifiedDevice(Statement statement, String model
LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices);
for (int retry = 0; retry < 200; retry++) {
Set<String> foundDevices = new HashSet<>();
try (final ResultSet resultSet =
statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
try (final ResultSet resultSet = statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
while (resultSet.next()) {
String deviceId = resultSet.getString("DeviceId");
String loadedModelId = resultSet.getString("ModelId");
Expand Down Expand Up @@ -193,8 +189,7 @@ public static void checkModelNotOnSpecifiedDevice(
LOGGER.info("Checking model: {} not on target devices: {}", modelId, targetDevices);
for (int retry = 0; retry < 50; retry++) {
Set<String> foundDevices = new HashSet<>();
try (final ResultSet resultSet =
statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
try (final ResultSet resultSet = statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
while (resultSet.next()) {
String deviceId = resultSet.getString("DeviceId");
String loadedModelId = resultSet.getString("ModelId");
Expand All @@ -215,16 +210,18 @@ public static void checkModelNotOnSpecifiedDevice(
fail("Model " + modelId + " is still loaded on device " + device);
}

private static final String[] WRITE_SQL_IN_TREE =
new String[] {
"CREATE DATABASE root.AI",
"CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
"CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
"CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
"CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
};
private static final String[] WRITE_SQL_IN_TREE = new String[] {
"CREATE DATABASE root.AI",
"CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
"CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
"CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
"CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
};

/** Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in tree. */
/**
* Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of
* data in tree.
*/
public static void prepareDataInTree() throws SQLException {
prepareData(WRITE_SQL_IN_TREE);
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Expand All @@ -238,7 +235,10 @@ public static void prepareDataInTree() throws SQLException {
}
}

/** Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in table. */
/**
* Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data
* in table.
*/
public static void prepareDataInTable() throws SQLException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
Expand Down
14 changes: 14 additions & 0 deletions iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def __repr__(self):
},
transformers_registered=True,
),
"toto": ModelInfo(
model_id="toto",
category=ModelCategory.BUILTIN,
Expand All @@ -172,5 +173,18 @@
"AutoModelForCausalLM": "modeling_toto.TotoForPrediction",
},
transformers_registered=True,
),
"patchtst_fm": ModelInfo(
model_id="patchtst_fm",
category=ModelCategory.BUILTIN,
state=ModelStates.INACTIVE,
model_type="patchtst_fm",
pipeline_cls="pipeline_patchtst_fm.PatchTSTFMPipeline",
repo_id="ibm-research/patchtst-fm-r1",
auto_map={
"AutoConfig": "configuration_patchtst_fm.PatchTSTFMConfig",
"AutoModelForCausalLM": "modeling_patchtst_fm.PatchTSTFMForPrediction",
},
),
}
Empty file.
Loading