0
0

ML.Net query modifier implementation

This commit is contained in:
Serghei Cebotari 2023-09-19 14:56:55 +03:00
parent a8ec0cee16
commit 12b557f53d
10 changed files with 168 additions and 1 deletions

1
.gitignore vendored
View File

@ -452,3 +452,4 @@ $RECYCLE.BIN/
!.vscode/tasks.json !.vscode/tasks.json
!.vscode/launch.json !.vscode/launch.json
!.vscode/extensions.json !.vscode/extensions.json
/RhSolutions.Api/MLModels

View File

@ -0,0 +1,30 @@
using Microsoft.AspNetCore.Http.Extensions;
using RhSolutions.Api.Services;
namespace RhSolutions.Api.Middleware;
public class QueryModifier
{
private RequestDelegate _next;
public QueryModifier(RequestDelegate nextDelegate)
{
_next = nextDelegate;
}
public async Task Invoke(HttpContext context, IProductTypePredicter typePredicter, ProductQueryModifierFactory productQueryModifierFactory)
{
if (context.Request.Method == HttpMethods.Get
&& context.Request.Path == "/api/search")
{
string query = context.Request.Query["query"].ToString();
var productType = typePredicter.GetPredictedProductType(query);
var modifier = productQueryModifierFactory.GetModifier(productType!);
if (modifier.TryQueryModify(context.Request.Query, out var newQuery))
{
context.Request.QueryString = newQuery;
}
}
await _next(context);
}
}

View File

@ -1,6 +1,7 @@
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using RhSolutions.Models; using RhSolutions.Models;
using RhSolutions.Api.Services; using RhSolutions.Api.Services;
using RhSolutions.Api.Middleware;
var builder = WebApplication.CreateBuilder(args); var builder = WebApplication.CreateBuilder(args);
@ -21,12 +22,15 @@ builder.Services.AddDbContext<RhSolutionsContext>(opts =>
opts.EnableSensitiveDataLogging(true); opts.EnableSensitiveDataLogging(true);
} }
}); });
builder.Services.AddScoped<IPricelistParser, ClosedXMLParser>(); builder.Services.AddScoped<IPricelistParser, ClosedXMLParser>()
.AddScoped<IProductTypePredicter, ProductTypePredicter>()
.AddSingleton<ProductQueryModifierFactory>();
builder.Services.AddControllers(); builder.Services.AddControllers();
var app = builder.Build(); var app = builder.Build();
app.MapControllers(); app.MapControllers();
app.UseMiddleware<QueryModifier>();
var context = app.Services.CreateScope().ServiceProvider var context = app.Services.CreateScope().ServiceProvider
.GetRequiredService<RhSolutionsContext>(); .GetRequiredService<RhSolutionsContext>();

View File

@ -13,9 +13,16 @@
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets> <PrivateAssets>all</PrivateAssets>
</PackageReference> </PackageReference>
<PackageReference Include="Microsoft.ML" Version="2.0.1" />
<PackageReference Include="Npgsql.EntityFrameworkCore.PostgreSQL" Version="7.0.4" /> <PackageReference Include="Npgsql.EntityFrameworkCore.PostgreSQL" Version="7.0.4" />
<PackageReference Include="Npgsql.EntityFrameworkCore.PostgreSQL.Design" Version="1.1.0" /> <PackageReference Include="Npgsql.EntityFrameworkCore.PostgreSQL.Design" Version="1.1.0" />
<PackageReference Include="Rhsolutions.ProductSku" Version="1.0.0" /> <PackageReference Include="Rhsolutions.ProductSku" Version="1.0.0" />
</ItemGroup> </ItemGroup>
<ItemGroup>
<None Update="MLModels\model.zip">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>
</Project> </Project>

View File

@ -0,0 +1,11 @@
namespace RhSolutions.Api.Services
{
public class BypassQueryModifier : IProductQueryModifier
{
public bool TryQueryModify(IQueryCollection collection, out QueryString queryString)
{
queryString = QueryString.Empty;
return false;
}
}
}

View File

@ -0,0 +1,7 @@
namespace RhSolutions.Api.Services
{
public interface IProductQueryModifier
{
public bool TryQueryModify(IQueryCollection collection, out QueryString queryString);
}
}

View File

@ -0,0 +1,6 @@
namespace RhSolutions.Api.Services;
public interface IProductTypePredicter
{
public string? GetPredictedProductType(string productName);
}

View File

@ -0,0 +1,15 @@
namespace RhSolutions.Api.Services;
public class ProductQueryModifierFactory
{
public IProductQueryModifier GetModifier(string productTypeName)
{
switch (productTypeName)
{
case "Тройник RAUTITAN":
return new TPieceQueryModifier();
default:
return new BypassQueryModifier();
}
}
}

View File

@ -0,0 +1,44 @@
using Microsoft.ML;
using Microsoft.ML.Data;
namespace RhSolutions.Api.Services;
public class ProductTypePredicter : IProductTypePredicter
{
private readonly string _modelPath = @"./MLModels/model.zip";
private MLContext _mlContext;
private ITransformer _loadedModel;
private PredictionEngine<Product, TypePrediction> _predEngine;
public ProductTypePredicter()
{
_mlContext = new MLContext(seed: 0);
_loadedModel = _mlContext.Model.Load(_modelPath, out var _);
_predEngine = _mlContext.Model.CreatePredictionEngine<Product, TypePrediction>(_loadedModel);
}
public string? GetPredictedProductType(string productName)
{
Product p = new()
{
Name = productName
};
var prediction = _predEngine.Predict(p);
return prediction.Type;
}
public class Product
{
[LoadColumn(0)]
public string? Name { get; set; }
[LoadColumn(1)]
public string? Type { get; set; }
}
public class TypePrediction
{
[ColumnName("PredictedLabel")]
public string? Type { get; set; }
}
}

View File

@ -0,0 +1,42 @@
using Microsoft.AspNetCore.Http.Extensions;
using System.Text;
using System.Text.RegularExpressions;
namespace RhSolutions.Api.Services
{
public class TPieceQueryModifier : IProductQueryModifier
{
private readonly string pattern = @"(\b16|20|25|32|40|50|63\b)+";
public bool TryQueryModify(IQueryCollection collection, out QueryString queryString)
{
queryString = QueryString.Empty;
var query = collection["query"].ToString();
if (string.IsNullOrEmpty(query))
{
return false;
}
var matches = Regex.Matches(query, pattern);
StringBuilder sb = new();
sb.Append("Тройник RAUTITAN -PLATINUM");
if (matches.Count == 1)
{
sb.Append($" {matches.First().Value}-{matches.First().Value}-{matches.First().Value}");
}
else if (matches.Count >= 3)
{
sb.Append($" {matches[0].Value}-{matches[1].Value}-{matches[2].Value}");
}
else
{
return false;
}
QueryBuilder qb = new()
{
{ "query", sb.ToString() }
};
queryString = qb.ToQueryString();
return true;
}
}
}