37 lines
1.0 KiB
C#
37 lines
1.0 KiB
C#
namespace RhSolutions.ML.Tests;

/// <summary>
/// Base class for model tests: loads the trained ML.NET model from <see cref="_dataPath"/> when a
/// fixture instance is constructed, and provides <c>Execute</c> helpers that assert the predicted
/// <c>Type</c> for a product name.
/// </summary>
public abstract class TestBase
{
    /// <summary>Directory the test binaries are running from.</summary>
    /// <remarks>
    /// <see cref="AppContext.BaseDirectory"/> is used instead of
    /// <c>Environment.GetCommandLineArgs()[0]</c>: under a test runner the first command-line
    /// argument is the host process (e.g. testhost), which is not guaranteed to live in the test
    /// output directory, so the relative model path below could resolve to the wrong place.
    /// </remarks>
    protected static readonly string _appPath = AppContext.BaseDirectory;

    /// <summary>
    /// Path to the serialized model file <c>Models/model.zip</c>, located four directory levels
    /// above <see cref="_appPath"/>.
    /// </summary>
    // NOTE(review): despite the name, this is the trained model, not training data. The four ".."
    // segments presumably climb out of bin/<Configuration>/<TargetFramework>/ to the solution
    // root — confirm if the output layout changes.
    // Must stay declared after _appPath: static field initializers run in textual order.
    protected static readonly string _dataPath =
        Path.Combine(_appPath, "..", "..", "..", "..", "Models", "model.zip");

    /// <summary>ML.NET context (seeded with 0) used to load the model.</summary>
    protected MLContext _mlContext;

    /// <summary>
    /// Prediction engine created from the loaded model and shared by the <c>Execute</c> helpers.
    /// </summary>
    // NOTE(review): ML.NET documents PredictionEngine as not thread-safe; avoid running tests that
    // share one fixture instance in parallel.
    protected PredictionEngine<Product, TypePrediction> _predEngine;

    /// <summary>
    /// Creates a seeded <see cref="MLContext"/>, loads the model from <see cref="_dataPath"/> and
    /// builds <see cref="_predEngine"/> from it.
    /// </summary>
    protected TestBase()
    {
        _mlContext = new MLContext(seed: 0);

        // The input schema reported by Load is not needed here, so it is discarded.
        ITransformer loadedModel = _mlContext.Model.Load(_dataPath, out _);
        _predEngine = _mlContext.Model.CreatePredictionEngine<Product, TypePrediction>(loadedModel);
    }

    /// <summary>
    /// Asserts that the model predicts <paramref name="expectedGroup"/> for a product named
    /// <paramref name="name"/>.
    /// </summary>
    /// <param name="name">Product name fed to the model.</param>
    /// <param name="expectedGroup">Expected value of the prediction's <c>Type</c>.</param>
    public void Execute(string name, string expectedGroup)
    {
        TypePrediction prediction = PredictByName(name);
        Assert.That(prediction.Type, Is.EqualTo(expectedGroup), $"Unexpected type predicted for '{name}'.");
    }

    /// <summary>
    /// Asserts that the model predicts <c>expected.Type</c> for a product named
    /// <c>expected.Name</c>.
    /// </summary>
    /// <remarks>
    /// A fresh <c>Product</c> carrying only <c>Name</c> is sent to the model, so the other
    /// fields of <paramref name="expected"/> (such as <c>Type</c>) are not used as input.
    /// </remarks>
    /// <param name="expected">Product whose name is the input and whose type is the expected label.</param>
    /// <exception cref="ArgumentNullException"><paramref name="expected"/> is <c>null</c>.</exception>
    public void Execute(Product expected)
    {
        ArgumentNullException.ThrowIfNull(expected);

        TypePrediction prediction = PredictByName(expected.Name);
        Assert.That(prediction.Type, Is.EqualTo(expected.Type), $"Unexpected type predicted for '{expected.Name}'.");
    }

    // Runs the shared prediction engine on an input that has only its Name set.
    private TypePrediction PredictByName(string name) =>
        _predEngine.Predict(new Product { Name = name });
}