KotlinCoroutineFilter.java

/*******************************************************************************
 * Copyright (c) 2009, 2024 Mountainminds GmbH & Co. KG and Contributors
 * This program and the accompanying materials are made available under
 * the terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 * Contributors:
 *    Evgeny Mandrikov - initial API and implementation
 *
 *******************************************************************************/
package org.jacoco.core.internal.analysis.filter;

import java.util.ArrayList;
import java.util.List;

import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.JumpInsnNode;
import org.objectweb.asm.tree.LdcInsnNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.TableSwitchInsnNode;

/**
 * Filters branches that Kotlin compiler generates for coroutines.
 *
 * <p>
 * The filter only acts on classes recognized as Kotlin classes by
 * {@link KotlinGeneratedFilter#isKotlinClass(IFilterContext)}. For such
 * classes two independent instruction patterns are searched and, when found,
 * reported to the {@link IFilterOutput} as ignored ranges:
 * </p>
 * <ul>
 * <li>the state machine that starts with a <code>TABLESWITCH</code> over the
 * coroutine state and whose default branch throws
 * <code>IllegalStateException</code>,</li>
 * <li>the comparison of a result with <code>COROUTINE_SUSPENDED</code>
 * followed by an early return.</li>
 * </ul>
 */
public final class KotlinCoroutineFilter implements IFilter {

	/** Internal name of the class that provides the suspension marker. */
	private static final String INTRINSICS_OWNER = "kotlin/coroutines/intrinsics/IntrinsicsKt";

	/** Name of the static getter that returns the suspension marker. */
	private static final String GET_SUSPENDED_NAME = "getCOROUTINE_SUSPENDED";

	/** Descriptor of the static getter that returns the suspension marker. */
	private static final String GET_SUSPENDED_DESC = "()Ljava/lang/Object;";

	/**
	 * Symbolic name under which the matcher tracks the local variable that
	 * holds the suspension marker.
	 */
	private static final String SUSPENDED_VAR = "COROUTINE_SUSPENDED";

	/**
	 * Checks whether the given method has the shape of a suspend function
	 * implementation.
	 *
	 * @param methodNode
	 *            method to inspect
	 * @return <code>true</code> if the method name does not start with
	 *         <code>access$</code> and the type of its last parameter is
	 *         <code>kotlin.coroutines.Continuation</code>;
	 *         <code>false</code> otherwise, including for methods without
	 *         parameters
	 */
	static boolean isImplementationOfSuspendFunction(
			final MethodNode methodNode) {
		if (methodNode.name.startsWith("access$")) {
			return false;
		}
		// Parse the descriptor once; every call to getArgumentTypes()
		// produces a new array.
		final Type[] argumentTypes = Type.getMethodType(methodNode.desc)
				.getArgumentTypes();
		final int lastArgument = argumentTypes.length - 1;
		return lastArgument >= 0 && "kotlin.coroutines.Continuation"
				.equals(argumentTypes[lastArgument].getClassName());
	}

	/**
	 * Applies both coroutine patterns to the given method. Does nothing when
	 * the class is not a Kotlin class.
	 *
	 * @param methodNode
	 *            method to analyze
	 * @param context
	 *            context used to determine whether the class is a Kotlin class
	 * @param output
	 *            receiver of the instruction ranges to ignore
	 */
	public void filter(final MethodNode methodNode,
			final IFilterContext context, final IFilterOutput output) {

		if (!KotlinGeneratedFilter.isKotlinClass(context)) {
			return;
		}

		// A fresh matcher per pattern, so that state of one pattern (cursor,
		// tracked variables) can not leak into the other.
		new Matcher().match(methodNode, output);
		new Matcher().matchOptimizedTailCall(methodNode, output);
	}

	/**
	 * Pattern matcher for coroutine related instruction sequences.
	 *
	 * <p>
	 * All methods follow the same convention: <code>cursor</code> points to
	 * the last matched instruction and a value of <code>null</code> means that
	 * the pattern did not match. NOTE(review): the <code>nextIs*</code>
	 * helpers are inherited from {@link AbstractMatcher}; the null checks in
	 * this class assume that they set <code>cursor</code> to
	 * <code>null</code> on mismatch and are no-ops once it is
	 * <code>null</code>.
	 * </p>
	 */
	private static class Matcher extends AbstractMatcher {

		/**
		 * For every instruction of the method tries to match the sequence
		 * <code>DUP, INVOKESTATIC getCOROUTINE_SUSPENDED, IF_ACMPNE, ARETURN,
		 * POP</code> directly after it and ignores the matched sequence.
		 *
		 * @param methodNode
		 *            method to analyze
		 * @param output
		 *            receiver of the instruction ranges to ignore
		 */
		private void matchOptimizedTailCall(final MethodNode methodNode,
				final IFilterOutput output) {
			for (final AbstractInsnNode i : methodNode.instructions) {
				cursor = i;
				nextIs(Opcodes.DUP);
				nextIsInvoke(Opcodes.INVOKESTATIC, INTRINSICS_OWNER,
						GET_SUSPENDED_NAME, GET_SUSPENDED_DESC);
				nextIs(Opcodes.IF_ACMPNE);
				nextIs(Opcodes.ARETURN);
				nextIs(Opcodes.POP);
				if (cursor != null) {
					// "i" itself only anchors the search and is not part of
					// the ignored range.
					output.ignore(i.getNext(), cursor);
				}
			}
		}

		/**
		 * Matches the state machine of a coroutine and ignores its generated
		 * parts: the method prologue up to the check of the initial result,
		 * the suspension checks of every suspension point and the default
		 * branch of the state switch. Nothing is reported unless the complete
		 * pattern including the default branch matches.
		 *
		 * @param methodNode
		 *            method to analyze
		 * @param output
		 *            receiver of the instruction ranges to ignore
		 */
		private void match(final MethodNode methodNode,
				final IFilterOutput output) {
			final AbstractInsnNode first = skipNonOpcodes(
					methodNode.instructions.getFirst());
			cursor = first;
			if (!isGetCoroutineSuspended(first)) {
				// The method does not start with the retrieval of the
				// suspension marker directly, so the retrieval must be
				// preceded by the creation of the state instance.
				nextIsCreateStateInstance();

				nextIsInvoke(Opcodes.INVOKESTATIC, INTRINSICS_OWNER,
						GET_SUSPENDED_NAME, GET_SUSPENDED_DESC);
			}

			nextIsVar(Opcodes.ASTORE, SUSPENDED_VAR);
			nextIsVar(Opcodes.ALOAD, "this");
			nextIs(Opcodes.GETFIELD);
			nextIs(Opcodes.TABLESWITCH);
			if (cursor == null) {
				return;
			}
			final TableSwitchInsnNode s = (TableSwitchInsnNode) cursor;
			// Flat list of (from, to) pairs, reported only after the whole
			// pattern has been confirmed.
			final List<AbstractInsnNode> ignore = new ArrayList<AbstractInsnNode>(
					s.labels.size() * 2);

			nextIs(Opcodes.ALOAD);
			nextIsThrowOnFailure();

			if (cursor == null) {
				return;
			}
			// Prologue: from the very first node of the method up to the
			// check of the initial result.
			ignore.add(methodNode.instructions.getFirst());
			ignore.add(cursor);

			// Label 0 of the switch is the initial state handled above, the
			// remaining labels are looked up one by one in instruction order.
			int suspensionPoint = 1;
			for (AbstractInsnNode i = cursor; i != null
					&& suspensionPoint < s.labels.size(); i = i.getNext()) {
				cursor = i;
				nextIsVar(Opcodes.ALOAD, SUSPENDED_VAR);
				nextIs(Opcodes.IF_ACMPNE);
				if (cursor == null) {
					continue;
				}
				final AbstractInsnNode continuationAfterLoadedResult = skipNonOpcodes(
						((JumpInsnNode) cursor).label);
				nextIsVar(Opcodes.ALOAD, SUSPENDED_VAR);
				nextIs(Opcodes.ARETURN);
				// The return of the suspension marker must be directly
				// followed by the switch target of the expected state.
				if (cursor == null
						|| skipNonOpcodes(cursor.getNext()) != skipNonOpcodes(
								s.labels.get(suspensionPoint))) {
					continue;
				}

				// Search for the end of the resume path: the check of the
				// result followed by a load that leads to the same
				// instruction as the not suspended path.
				for (AbstractInsnNode j = i; j != null; j = j.getNext()) {
					cursor = j;
					nextIs(Opcodes.ALOAD);
					nextIsThrowOnFailure();

					nextIs(Opcodes.ALOAD);
					if (cursor != null && skipNonOpcodes(cursor
							.getNext()) == continuationAfterLoadedResult) {
						ignore.add(i);
						ignore.add(cursor);
						suspensionPoint++;
						break;
					}
				}
			}

			// Default branch of the switch must be exactly:
			// throw new IllegalStateException("call to 'resume' ...")
			cursor = s.dflt;
			nextIsType(Opcodes.NEW, "java/lang/IllegalStateException");
			nextIs(Opcodes.DUP);
			nextIs(Opcodes.LDC);
			if (cursor == null) {
				return;
			}
			if (!"call to 'resume' before 'invoke' with coroutine"
					.equals(((LdcInsnNode) cursor).cst)) {
				return;
			}
			nextIsInvoke(Opcodes.INVOKESPECIAL,
					"java/lang/IllegalStateException", "<init>",
					"(Ljava/lang/String;)V");
			nextIs(Opcodes.ATHROW);
			if (cursor == null) {
				return;
			}

			output.ignore(s.dflt, cursor);
			for (int i = 0; i < ignore.size(); i += 2) {
				output.ignore(ignore.get(i), ignore.get(i + 1));
			}
		}

		/**
		 * Checks whether the given instruction is the invocation of the static
		 * getter of the suspension marker.
		 *
		 * @param insn
		 *            instruction to check, may be <code>null</code>
		 * @return <code>true</code> if the instruction is
		 *         <code>INVOKESTATIC IntrinsicsKt.getCOROUTINE_SUSPENDED()</code>
		 */
		private static boolean isGetCoroutineSuspended(
				final AbstractInsnNode insn) {
			if (insn == null || insn.getOpcode() != Opcodes.INVOKESTATIC) {
				return false;
			}
			final MethodInsnNode m = (MethodInsnNode) insn;
			return INTRINSICS_OWNER.equals(m.owner)
					&& GET_SUSPENDED_NAME.equals(m.name)
					&& GET_SUSPENDED_DESC.equals(m.desc);
		}

		/**
		 * Matches the check of a result for failure in one of its two forms:
		 * either the invocation of
		 * <code>kotlin.ResultKt.throwOnFailure(Object)</code> or the inlined
		 * <code>instanceof kotlin.Result$Failure</code> check. On mismatch
		 * of both forms <code>cursor</code> is <code>null</code>.
		 */
		private void nextIsThrowOnFailure() {
			final AbstractInsnNode c = cursor;
			nextIsInvoke(Opcodes.INVOKESTATIC, "kotlin/ResultKt",
					"throwOnFailure", "(Ljava/lang/Object;)V");
			if (cursor == null) {
				// Rewind and try the older form.
				cursor = c;
				// Before resolution of
				// https://youtrack.jetbrains.com/issue/KT-28015 in
				// Kotlin 1.3.30
				nextIs(Opcodes.DUP);
				nextIsType(Opcodes.INSTANCEOF, "kotlin/Result$Failure");
				nextIs(Opcodes.IFEQ);
				nextIsType(Opcodes.CHECKCAST, "kotlin/Result$Failure");
				nextIs(Opcodes.GETFIELD);
				nextIs(Opcodes.ATHROW);
				nextIs(Opcodes.POP);
			}
		}

		/**
		 * Matches the sequence that either reuses the state instance passed
		 * to the method or jumps to the creation of a new one, and positions
		 * <code>cursor</code> on the <code>ASTORE</code> that follows the
		 * join point of both paths.
		 *
		 * <p>
		 * On any mismatch, including jumps that do not lead to the expected
		 * targets, <code>cursor</code> is set to <code>null</code> so that
		 * the caller does not continue matching in the middle of an unrelated
		 * sequence.
		 * </p>
		 */
		private void nextIsCreateStateInstance() {
			nextIs(Opcodes.INSTANCEOF);

			nextIs(Opcodes.IFEQ);
			if (cursor == null) {
				return;
			}
			final AbstractInsnNode createStateInstance = skipNonOpcodes(
					((JumpInsnNode) cursor).label);

			nextIs(Opcodes.ALOAD);
			nextIs(Opcodes.CHECKCAST);
			nextIs(Opcodes.ASTORE);

			nextIs(Opcodes.ALOAD);
			nextIs(Opcodes.GETFIELD);

			nextIs(Opcodes.LDC);
			nextIs(Opcodes.IAND);
			nextIs(Opcodes.IFEQ);
			if (cursor == null) {
				return;
			}
			// Both conditional jumps must lead to the same instruction.
			if (skipNonOpcodes(
					((JumpInsnNode) cursor).label) != createStateInstance) {
				cursor = null;
				return;
			}

			nextIs(Opcodes.ALOAD);
			nextIs(Opcodes.DUP);
			nextIs(Opcodes.GETFIELD);

			nextIs(Opcodes.LDC);
			nextIs(Opcodes.ISUB);
			nextIs(Opcodes.PUTFIELD);

			nextIs(Opcodes.GOTO);
			if (cursor == null) {
				return;
			}
			final AbstractInsnNode afterCoroutineStateCreated = skipNonOpcodes(
					((JumpInsnNode) cursor).label);

			// The target of the conditional jumps must directly follow the
			// GOTO.
			if (skipNonOpcodes(cursor.getNext()) != createStateInstance) {
				cursor = null;
				return;
			}

			// Skip the creation of the new instance and continue at the
			// instruction where both paths join.
			cursor = afterCoroutineStateCreated;
			nextIs(Opcodes.GETFIELD);
			nextIs(Opcodes.ASTORE);
		}
	}

}