How to wrap Entity Framework to intercept the LINQ expression just before execution?

14,771

Solution 1

Based on Arthur's answer, I've created a working wrapper.

The snippets below wrap each LINQ query with your own query provider and IQueryable root. This means you have to control where the initial query is created (which you usually do anyway when using any sort of pattern, such as a repository).

The problem with this method is that it's not transparent; ideally, you would inject something into the entities container at the constructor level.

I've created a compilable implementation, got it to work with Entity Framework, and added support for the ObjectQuery.Include method. The expression visitor base class can be copied from MSDN.
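On .NET Framework 4 and later, ExpressionVisitor ships as a public class in System.Linq.Expressions, so copying it from MSDN is only needed on .NET 3.5. The core trick the provider below relies on is a visitor that swaps the wrapper's root constant for the real query's expression before the tree is handed to the inner provider. Here is a minimal sketch of just that step, with LINQ to Objects standing in for EF (RootSwapper and RootSwapDemo are hypothetical names, not part of either answer):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical visitor: swaps a placeholder query root for a real source's expression,
// the same job QueryTranslatorProvider<T>.VisitConstant does for the EF ObjectQuery.
public class RootSwapper : ExpressionVisitor
{
    private readonly object placeholder;
    private readonly Expression replacement;

    public RootSwapper(object placeholder, Expression replacement)
    {
        this.placeholder = placeholder;
        this.replacement = replacement;
    }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        // Replace only the exact placeholder instance; all other constants are untouched.
        return ReferenceEquals(node.Value, placeholder) ? replacement : base.VisitConstant(node);
    }
}

public static class RootSwapDemo
{
    public static List<int> Run()
    {
        IQueryable<int> placeholder = new int[0].AsQueryable();            // plays the wrapper root
        IQueryable<int> realSource = new[] { 1, 2, 3, 4, 5 }.AsQueryable(); // plays the EF query

        // Build the query against the placeholder, then rewrite it just before execution.
        Expression query = placeholder.Where(x => x % 2 == 1).Expression;
        Expression rewritten = new RootSwapper(placeholder, realSource.Expression).Visit(query);

        return realSource.Provider.CreateQuery<int>(rewritten).ToList();
    }
}
```

Run() yields 1, 3 and 5: the Where that was built against the empty placeholder executes against realSource once the root has been swapped.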

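using System;
using System.Collections;
using System.Collections.Generic;
using System.Data.Objects; // ObjectQuery<T> (EF 4; System.Data.Entity.Core.Objects in EF 6)
using System.Linq;
using System.Linq.Expressions;

public class QueryTranslator<T> : IOrderedQueryable<T>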
{
    private Expression expression = null;
    private QueryTranslatorProvider<T> provider = null;

    public QueryTranslator(IQueryable source)
    {
        expression = Expression.Constant(this);
        provider = new QueryTranslatorProvider<T>(source);
    }

    public QueryTranslator(IQueryable source, Expression e)
    {
        if (e == null) throw new ArgumentNullException("e");
        expression = e;
        provider = new QueryTranslatorProvider<T>(source);
    }

    public IEnumerator<T> GetEnumerator()
    {
        return ((IEnumerable<T>)provider.ExecuteEnumerable(this.expression)).GetEnumerator();
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return provider.ExecuteEnumerable(this.expression).GetEnumerator();
    }

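    public QueryTranslator<T> Include(String path)
    {
        // Include is only valid on the root: this translator must still be the bare
        // constant root, and the wrapped source must be an EF ObjectQuery<T>.
        // (Derived translators share the root's source, so without the root check
        // an Include after a Where would silently drop the Where.)
        ObjectQuery<T> possibleObjectQuery = provider.source as ObjectQuery<T>;
        if (possibleObjectQuery != null && expression is ConstantExpression)
        {
            return new QueryTranslator<T>(possibleObjectQuery.Include(path));
        }
        else
        {
            throw new InvalidOperationException("The Include should only happen at the beginning of a LINQ expression");
        }
    }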

    public Type ElementType
    {
        get { return typeof(T); }
    }

    public Expression Expression
    {
        get { return expression; }
    }

    public IQueryProvider Provider
    {
        get { return provider; }
    }
}

public class QueryTranslatorProvider<T> : ExpressionVisitor, IQueryProvider
{
    internal IQueryable source;

    public QueryTranslatorProvider(IQueryable source)
    {
        if (source == null) throw new ArgumentNullException("source");
        this.source = source;
    }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        return new QueryTranslator<TElement>(source, expression) as IQueryable<TElement>;
    }

    public IQueryable CreateQuery(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        Type elementType = expression.Type.GetGenericArguments().First();
        IQueryable result = (IQueryable)Activator.CreateInstance(typeof(QueryTranslator<>).MakeGenericType(elementType),
            new object[] { source, expression });
        return result;
    }

    public TResult Execute<TResult>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        object result = (this as IQueryProvider).Execute(expression);
        return (TResult)result;
    }

    public object Execute(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);
        return source.Provider.Execute(translated);
    }

    internal IEnumerable ExecuteEnumerable(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);
        return source.Provider.CreateQuery(translated);
    }

    #region Visitors
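    protected override Expression VisitConstant(ConstantExpression c)
    {
        // fix up the Expression tree to work with EF again: replace the wrapper root,
        // whatever its element type (a Select can change T), with the original EF query
        if (c.Type.IsGenericType && c.Type.GetGenericTypeDefinition() == typeof(QueryTranslator<>))
        {
            return source.Expression;
        }
        else
        {
            return base.VisitConstant(c);
        }
    }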
    #endregion
}

Example usage in your repository:

public IQueryable<User> List()
{
    return new QueryTranslator<User>(entities.Users).Include("Department");
}
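To try the interception end to end without a database, here is a minimal sketch of the same wrapping pattern over LINQ to Objects. InterceptingQuery<T>, InterceptingProvider, Root<T>() and InterceptDemo are hypothetical names; the in-memory array stands in for entities.Users, and the rewrite hook only counts calls, which is the spot where an EF version would run its own ExpressionVisitor. The root is recognised by its provider rather than by its element type, so a projection such as the Select below still has its root swapped back before execution:

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical wrapper query: a null expression means "I am the root".
public class InterceptingQuery<T> : IOrderedQueryable<T>
{
    private readonly InterceptingProvider provider;

    public InterceptingQuery(InterceptingProvider provider, Expression expression)
    {
        this.provider = provider;
        Expression = expression ?? Expression.Constant(this);
    }

    public Type ElementType => typeof(T);
    public Expression Expression { get; }
    public IQueryProvider Provider => provider;
    public IEnumerator<T> GetEnumerator() => provider.Translate<T>(Expression).GetEnumerator();
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

// Hypothetical provider: runs a rewrite hook, swaps the root back, then delegates.
public class InterceptingProvider : IQueryProvider
{
    private readonly IQueryable source;                    // real root (entities.Users in EF)
    private readonly Func<Expression, Expression> rewrite;  // called just before execution

    public InterceptingProvider(IQueryable source, Func<Expression, Expression> rewrite)
    {
        this.source = source;
        this.rewrite = rewrite;
    }

    public IQueryable<T> Root<T>() => new InterceptingQuery<T>(this, null);

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression) =>
        new InterceptingQuery<TElement>(this, expression);

    public IQueryable CreateQuery(Expression expression)
    {
        // Non-generic path: take T from the IEnumerable<T> the expression's type is or implements.
        Type seq = new[] { expression.Type }.Concat(expression.Type.GetInterfaces())
            .First(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IEnumerable<>));
        Type queryType = typeof(InterceptingQuery<>).MakeGenericType(seq.GetGenericArguments()[0]);
        return (IQueryable)Activator.CreateInstance(queryType, this, expression);
    }

    public TResult Execute<TResult>(Expression expression) =>
        source.Provider.Execute<TResult>(Prepare(expression));

    public object Execute(Expression expression) => source.Provider.Execute(Prepare(expression));

    internal IQueryable<T> Translate<T>(Expression expression) =>
        source.Provider.CreateQuery<T>(Prepare(expression));

    // Apply the user's rewrite first, then replace our root constant with the real source.
    private Expression Prepare(Expression expression) =>
        new RootReplacer(this, source.Expression).Visit(rewrite(expression));

    private sealed class RootReplacer : ExpressionVisitor
    {
        private readonly InterceptingProvider owner;
        private readonly Expression replacement;

        public RootReplacer(InterceptingProvider owner, Expression replacement)
        {
            this.owner = owner;
            this.replacement = replacement;
        }

        // Any constant whose IQueryable.Provider is this provider is one of our roots.
        protected override Expression VisitConstant(ConstantExpression node) =>
            node.Value is IQueryable q && ReferenceEquals(q.Provider, owner)
                ? replacement
                : base.VisitConstant(node);
    }
}

public static class InterceptDemo
{
    public static (List<string> Names, int Count, int Intercepted) Run()
    {
        int intercepted = 0;
        IQueryable<string> people = new[] { "Ann", "Bob", "Carla" }.AsQueryable();
        var provider = new InterceptingProvider(people, e => { intercepted++; return e; });
        IQueryable<string> root = provider.Root<string>();

        List<string> names = root.Where(n => n.Length > 3).Select(n => n.ToUpper()).ToList();
        int count = root.Count(n => n.StartsWith("B"));
        return (names, count, intercepted);
    }
}
```

Run() returns the single name "CARLA", a count of 1, and 2 interceptions: one when ToList() enumerates the query and one when Count() goes through Execute<int>.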

Solution 2

I have exactly the source code you'll need, but no idea how to attach a file.

Here are some snippets (snippets! I had to adapt this code, so it may not compile):

IQueryable:

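using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public class QueryTranslator<T> : IOrderedQueryable<T>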
{
    private Expression _expression = null;
    private QueryTranslatorProvider<T> _provider = null;

    public QueryTranslator(IQueryable source)
    {
        _expression = Expression.Constant(this);
        _provider = new QueryTranslatorProvider<T>(source);
    }

    public QueryTranslator(IQueryable source, Expression e)
    {
        if (e == null) throw new ArgumentNullException("e");
        _expression = e;
        _provider = new QueryTranslatorProvider<T>(source);
    }

    public IEnumerator<T> GetEnumerator()
    {
        return ((IEnumerable<T>)_provider.ExecuteEnumerable(this._expression)).GetEnumerator();
    }

    IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return _provider.ExecuteEnumerable(this._expression).GetEnumerator();
    }

    public Type ElementType
    {
        get { return typeof(T); }
    }

    public Expression Expression
    {
        get { return _expression; }
    }

    public IQueryProvider Provider
    {
        get { return _provider; }
    }
}

IQueryProvider:

public class QueryTranslatorProvider<T> : ExpressionTreeTranslator, IQueryProvider
{
    IQueryable _source;

    public QueryTranslatorProvider(IQueryable source)
    {
        if (source == null) throw new ArgumentNullException("source");
        _source = source;
    }

    #region IQueryProvider Members

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        return new QueryTranslator<TElement>(_source, expression) as IQueryable<TElement>;
    }

    public IQueryable CreateQuery(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Type elementType = expression.Type.FindElementTypes().First();
        IQueryable result = (IQueryable)Activator.CreateInstance(typeof(QueryTranslator<>).MakeGenericType(elementType),
            new object[] { _source, expression });
        return result;
    }

    public TResult Execute<TResult>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        object result = (this as IQueryProvider).Execute(expression);
        return (TResult)result;
    }

    public object Execute(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);

        return _source.Provider.Execute(translated);            
    }

    internal IEnumerable ExecuteEnumerable(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);

        return _source.Provider.CreateQuery(translated);
    }

    #endregion        

    #region Visits
    protected override MethodCallExpression VisitMethodCall(MethodCallExpression m)
    {
        return m;
    }

    protected override Expression VisitUnary(UnaryExpression u)
    {
        return Expression.MakeUnary(u.NodeType, base.Visit(u.Operand), u.Type.ToImplementationType(), u.Method);
    }
    #endregion
}

Usage (warning: adapted code! May not compile):

private Dictionary<Type, object> _table = new Dictionary<Type, object>();
public override IQueryable<T> GetObjectQuery<T>()
{
    Type type = typeof(T);
    if (!_table.ContainsKey(type))
    {
        _table[type] = new QueryTranslator<T>(
            _ctx.CreateQuery<T>("[" + typeof(T).Name + "]"));
    }

    return (IQueryable<T>)_table[type];
}

Expression Visitors/Translator:

http://blogs.msdn.com/mattwar/archive/2007/07/31/linq-building-an-iqueryable-provider-part-ii.aspx

http://msdn.microsoft.com/en-us/library/bb882521.aspx

EDIT: Added FindElementTypes(). Hopefully all methods are present now.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

// Extension methods must live in a non-nested, non-generic static class (the name is arbitrary)
public static class TypeExtensions
{
    /// <summary>
    /// Finds all implemented IEnumerables of the given Type
    /// </summary>
    public static IQueryable<Type> FindIEnumerables(this Type seqType)
    {
        if (seqType == null || seqType == typeof(object) || seqType == typeof(string))
            return new Type[] { }.AsQueryable();

        if (seqType.IsArray || seqType == typeof(IEnumerable))
            return new Type[] { typeof(IEnumerable) }.AsQueryable();

        if (seqType.IsGenericType && seqType.GetGenericArguments().Length == 1 && seqType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
        {
            return new Type[] { seqType, typeof(IEnumerable) }.AsQueryable();
        }

        var result = new List<Type>();

        foreach (var iface in (seqType.GetInterfaces() ?? new Type[] { }))
        {
            result.AddRange(FindIEnumerables(iface));
        }

        return FindIEnumerables(seqType.BaseType).Union(result);
    }

    /// <summary>
    /// Finds all element types provided by a specified sequence type.
    /// "Element types" are T for IEnumerable&lt;T&gt; and object for IEnumerable.
    /// </summary>
    public static IQueryable<Type> FindElementTypes(this Type seqType)
    {
        return seqType.FindIEnumerables().Select(t => t.IsGenericType ? t.GetGenericArguments().Single() : typeof(object));
    }
}
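FindElementTypes() returns every candidate element type, and CreateQuery takes the First() one. Since GetInterfaces() does not guarantee any particular order, for IQueryable<T> that first candidate could in principle be object (from the non-generic IEnumerable) rather than T. If you only need T for the query types CreateQuery receives, a smaller helper that prefers a generic IEnumerable<T> may be enough. This is a hypothetical simplification (ElementTypeHelper is not from either answer); unlike FindIEnumerables it treats string as a sequence of char, and for a type implementing several IEnumerable<T> it simply picks whichever one it finds first:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class ElementTypeHelper
{
    // Hypothetical helper: returns T for the first IEnumerable<T> that seqType is or
    // implements, and falls back to object for non-generic sequences.
    public static Type GetElementType(Type seqType)
    {
        Type enumerable = new[] { seqType }
            .Concat(seqType.GetInterfaces())
            .FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IEnumerable<>));
        return enumerable != null ? enumerable.GetGenericArguments()[0] : typeof(object);
    }
}
```

In CreateQuery it would take the place of expression.Type.FindElementTypes().First(), i.e. ElementTypeHelper.GetElementType(expression.Type).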

Solution 3

Just wanted to add to Arthur's example.

As Arthur warned, there is a bug in his GetObjectQuery() method.

It creates the base query using typeof(T).Name as the name of the EntitySet.

But the EntitySet name is a separate concept from the type name, and the two need not match.

If you are using EF 4, you should do this instead:

public override IQueryable<T> GetObjectQuery<T>()
{
    Type type = typeof(T);
    if (!_table.ContainsKey(type))
    {
        _table[type] = new QueryTranslator<T>(
            _ctx.CreateObjectSet<T>());
    }

    return (IQueryable<T>)_table[type];
}

This works as long as you don't have Multiple Entity Sets per Type (MEST), which is very rare.

If you are using EF in .NET 3.5, you can use the code in my Tip 13 to get the EntitySet name and feed it in like this:

public override IQueryable<T> GetObjectQuery<T>()
{
    Type type = typeof(T);
    if (!_table.ContainsKey(type))
    {
        _table[type] = new QueryTranslator<T>(
            _ctx.CreateQuery<T>("[" + GetEntitySetName<T>() + "]"));

    } 
    return (IQueryable<T>)_table[type];
}

Hope this helps

Alex

Entity Framework Tips


Author: Naveed

I am a proud father, the co-founder and CEO of SWAT.engineering, an enthusiastic Software Engineer, and have a PhD in Software Engineering research. I contribute to open source when luxury permits, and am one of the core-developers of the Rascal - Meta Programming Language. Languages I learned and still use: Rascal-MPL, Java, C, Bash, Javascript, R, C#, Ruby, Python, C++, Matlab, LaTeX. (This list should be considered outdated when you're reading this ;-))

Updated on April 17, 2022

Comments

  • Naveed (about 2 years ago)

    I want to rewrite certain parts of the LINQ expression just before execution, but I'm having trouble injecting my rewriter in the correct place (or anywhere at all, actually).

    Looking at the Entity Framework source (in Reflector), everything eventually comes down to IQueryProvider.Execute, which EF couples to the expression through the ObjectContext's internal IQueryProvider Provider { get; } property.

    So I created a wrapper class (implementing IQueryProvider) that rewrites the Expression when Execute is called and then passes it on to the original Provider.

    The problem is that Provider is backed by the field private ObjectQueryProvider _queryProvider, and ObjectQueryProvider is an internal sealed class, so it's not possible to create a subclass that adds the rewriting.

    So this approach got me to a dead end due to the very tightly coupled ObjectContext.

    How to solve this problem? Am I looking in the wrong direction? Is there perhaps a way to inject myself around this ObjectQueryProvider?

    Update: While the provided solutions all work when you're "wrapping" the ObjectContext using the Repository pattern, a solution that allows direct use of the generated ObjectContext subclass would be preferable, because it would remain compatible with Dynamic Data scaffolding.

  • Naveed (over 14 years ago)
    If I'm not mistaken, I would have to change the sets in the generated ObjectContext subclass to use this wrapper instead of the base.CreateQuery call? That's not exactly a nice solution, because regenerating the code would destroy my changes. Or am I misinterpreting your usage example?
  • Naveed (over 14 years ago)
    Hi, could you provide the "ExpressionTreeTranslator"? I'm guessing it's an implementation of the Expression Tree visitor pattern?
  • Arthur (over 14 years ago)
    @First comment: right, you wrap the CreateQuery calls. I have my own generator, so I have no trouble with that. I also have my own generic GetQuery method that creates the correct EF query and wraps it; I'll post that method. @Second: You can find the QueryTranslator here: msdn.microsoft.com/en-us/library/bb882521.aspx or here blogs.msdn.com/mattwar/archive/2007/07/31/…
  • Naveed (over 14 years ago)
    Okay, so it was just a rename of those classes; that's what I thought. But could you provide the extension method called FindElementTypes()? I can't find that one with Google either.
  • Naveed (over 14 years ago)
    I'm sorry, but I can't get this to actually work: the EF provider does not like the query with the QueryTranslator wrapped around it. -- System.NotSupportedException: Unable to create a constant value of type 'QueryTranslator`1'. Only primitive types ('such as Int32, String, and Guid') are supported in this context.
  • Arthur (over 14 years ago)
    Do you have everything you need now? Should I provide some more helper methods or look something up in my code?
  • Naveed (over 14 years ago)
    No, I got it to work, but it's strange that you left out the part that converts it back to an EF query.